Skip to content

Commit

Permalink
Merge pull request #4 from EQuS/2/remove-iso
Browse files Browse the repository at this point in the history
2/remove iso
  • Loading branch information
Phionx authored Mar 1, 2024
2 parents 4e69b7e + 2b8ff50 commit 600513b
Show file tree
Hide file tree
Showing 6 changed files with 2,180 additions and 25 deletions.
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,28 @@ Please use `pip install -e '.[dev, docs]'` if you are a `zsh` user.

Installing the package in the usual non-editable mode would require a developer to upgrade their pip installation (i.e. run `pip install --upgrade .`) every time they update the package source code.

#### Install with GPU support (Linux)

For linux users who wish to enable Nvidia GPU support, here are some steps ([ref](https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu)):

1. Make sure you NVIDIA drivers by running:
`cat /proc/driver/nvidia/version` or `sudo ubuntu-drivers list`
2. If your driver version is >= 525.60.13 then run:
`pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` otherwise, use `cuda11_pip`
3. Test that GPU support is enabled:
4. Enjoy!

***Notes:***
If you receive this error:
```
2024-02-27 14:10:45.052355: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
```

Then, you should update your NVIDIA driver by running:
```
conda install cuda -c nvidia
```

## Documentation

Documentation should be viewable here: [https://github.com/pages/EQuS/jaxquantum/](https://github.com/pages/EQuS/jaxquantum/)
Expand Down
47 changes: 32 additions & 15 deletions jaxquantum/quantum/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
Common jax <-> qutip-inspired functions
"""

from jax.config import config
from jax import config, Array
from jax.nn import one_hot

import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import qutip as qt
from qutip import Qobj

from jaxquantum.utils.utils import is_1d

Expand Down Expand Up @@ -42,11 +42,11 @@ def jax2qt(jax_obj, dims=None):
Returns:
QuTiP state.
"""
if isinstance(jax_obj, qt.Qobj) or jax_obj is None:
if isinstance(jax_obj, Qobj) or jax_obj is None:
return jax_obj
if dims is not None:
dims = np.array(dims).astype(int).tolist()
return qt.Qobj(np.array(jax_obj), dims=dims)
return Qobj(np.array(jax_obj), dims=dims)


# QuTiP alternatives in JAX (some are a WIP)
Expand All @@ -69,6 +69,18 @@ def unit(rho: jnp.ndarray, use_density_matrix=False):
return rho / jnp.linalg.norm(rho)


def ket(vec: Array) -> Array:
"""Turns a vector array into a ket.
Args:
vec: vector
Returns:
ket
"""
return vec.reshape(vec.shape[0], 1)


def dag(op: jnp.ndarray) -> jnp.ndarray:
"""Conjugate transpose.
Expand All @@ -78,8 +90,10 @@ def dag(op: jnp.ndarray) -> jnp.ndarray:
Returns:
conjugate transpose of op
"""
op = op.reshape(op.shape[0], -1) # adds dimension to 1D array if needed
return jnp.conj(op).T


def batch_dag(op: jnp.ndarray) -> jnp.ndarray:
"""Conjugate transpose.
Expand All @@ -89,7 +103,9 @@ def batch_dag(op: jnp.ndarray) -> jnp.ndarray:
Returns:
conjugate of op, and transposes last two axes
"""
return jnp.moveaxis(jnp.conj(op), -1, -2) # transposes last two axes, good for batching
return jnp.moveaxis(
jnp.conj(op), -1, -2
) # transposes last two axes, good for batching


def ket2dm(ket: jnp.ndarray) -> jnp.ndarray:
Expand Down Expand Up @@ -199,18 +215,19 @@ def num(N) -> jnp.ndarray:
return jnp.diag(jnp.arange(N))


def coherent(N, alpha) -> jnp.ndarray:
def coherent(N, α) -> jnp.ndarray:
"""Coherent state.
TODO: add trimming!
Args:
N: Hilbert Space Size
alpha: coherent state amplitude
N: Hilbert Space Size.
α: coherent state amplitude.
Return:
coherent state |alpha>
Coherent state |α⟩.
"""
# TODO: replace with JAX implementation
return qt2jax(qt.coherent(int(N), complex(alpha)))
return displace(N, α) @ basis(N, 0)


def identity(*args, **kwargs) -> jnp.ndarray:
Expand All @@ -223,17 +240,17 @@ def identity(*args, **kwargs) -> jnp.ndarray:


def displace(N, α) -> jnp.ndarray:
"""Displace operator
"""Displacement operator
Args:
N: Hilbert Space Size
α: displacement
α: Phase space displacement
Returns:
Displace operator D(α)
"""
# TODO: replace with JAX implementation
return qt2jax(qt.displace(int(N), float(α)))
a = destroy(N)
return expm(α * dag(a) - jnp.conj(α) * a)


def ptrace(rho, indx, dims):
Expand Down
164 changes: 155 additions & 9 deletions jaxquantum/quantum/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
def spre(op: jnp.ndarray) -> Callable[[jnp.ndarray], jnp.ndarray]:
"""Superoperator generator.
Args:
op: operator to be turned into a superoperator
Returns:
superoperator function
"""
op_dag = op.conj().T
return lambda rho: 0.5 * (
2 * op @ rho @ op_dag - rho @ op_dag @ op - op_dag @ op @ rho
)


def spre_iso(op: jnp.ndarray) -> Callable[[jnp.ndarray], jnp.ndarray]:
"""Superoperator generator.
Args:
op: operator to be turned into a superoperator
Expand All @@ -41,7 +56,77 @@ def spre(op: jnp.ndarray) -> Callable[[jnp.ndarray], jnp.ndarray]:
static_argnums=(4,),
)
def mesolve(
p: jnp.ndarray,
ρ0: jnp.ndarray,
t_list: jnp.ndarray,
c_ops: Optional[List[jnp.ndarray]] = jnp.array([]),
H0: Optional[jnp.ndarray] = None,
Ht: Optional[Callable[[float], jnp.ndarray]] = None,
):
"""Quantum Master Equation solver.
Args:
ρ0: initial state, must be a density matrix. For statevector evolution, please use sesolve.
t_list: time list
c_ops: list of collapse operators
H0: time independent Hamiltonian. If H0 is not None, it will override Ht.
Ht: time dependent Hamiltonian function.
Returns:
list of states
"""

ρ0 = jnp.asarray(ρ0) + 0.0j
c_ops = jnp.asarray(c_ops) + 0.0j
H0 = jnp.asarray(H0) + 0.0j if H0 is not None else H0

def f(
t: float,
rho: jnp.ndarray,
args: jnp.ndarray,
):
H0_val = args[0]
c_ops_val = args[1]

if H0_val is not None:
H = H0_val # use H0 if given
else:
H = Ht(t) # type: ignore
H = H + 0.0j

rho_dot = -1j * (H @ rho - rho @ H)

for op in c_ops_val:
rho_dot += spre(op)(rho)

return rho_dot

term = ODETerm(f)
solver = Dopri5()
saveat = SaveAt(ts=t_list)
stepsize_controller = PIDController(rtol=1e-6, atol=1e-6)

sol = diffeqsolve(
term,
solver,
t0=t_list[0],
t1=t_list[-1],
dt0=t_list[1] - t_list[0],
y0=ρ0,
saveat=saveat,
stepsize_controller=stepsize_controller,
args=[H0, c_ops],
max_steps=16**5,
)

return sol.ys


@partial(
jit,
static_argnums=(4,),
)
def mesolve_iso(
ρ0: jnp.ndarray,
t_list: jnp.ndarray,
c_ops: Optional[List[jnp.ndarray]] = jnp.array([]),
H0: Optional[jnp.ndarray] = None,
Expand All @@ -50,7 +135,7 @@ def mesolve(
"""Quantum Master Equation solver.
Args:
p: initial state, must be a density matrix. For statevector evolution, please use sesolve.
ρ0: initial state, must be a density matrix. For statevector evolution, please use sesolve.
t_list: time list
c_ops: list of collapse operators
H0: time independent Hamiltonian. If H0 is not None, it will override Ht.
Expand All @@ -60,9 +145,9 @@ def mesolve(
list of states
"""

p = complex_to_real_iso_matrix(p + 0.0j)
c_ops = vmap(complex_to_real_iso_matrix)(c_ops + 0.0j)
H0 = None if H0 is None else complex_to_real_iso_matrix(H0 + 0.0j)
ρ0 = complex_to_real_iso_matrix(jnp.asarray(ρ0) + 0.0j)
c_ops = vmap(complex_to_real_iso_matrix)(jnp.asarray(c_ops) + 0.0j)
H0 = None if H0 is None else complex_to_real_iso_matrix(jnp.asarray(H0) + 0.0j)

def f(
t: float,
Expand Down Expand Up @@ -96,7 +181,7 @@ def f(
t0=t_list[0],
t1=t_list[-1],
dt0=t_list[1] - t_list[0],
y0=p,
y0=ρ0,
saveat=saveat,
stepsize_controller=stepsize_controller,
args=[H0, c_ops],
Expand All @@ -116,7 +201,68 @@ def sesolve(
H0: Optional[jnp.ndarray] = None,
Ht: Optional[Callable[[float], jnp.ndarray]] = None,
):
"""Schroedinger Equation solver.
"""Schrödinger Equation solver.
Args:
ψ: initial statevector
t_list: time list
H0: time independent Hamiltonian. If H0 is not None, it will override Ht.
Ht: time dependent Hamiltonian function.
Returns:
list of states
"""
ψ = jnp.asarray(ψ) + 0.0j
H0 = None if H0 is None else jnp.asarray(H0) + 0.0j

def f(
t: float,
ψₜ: jnp.ndarray,
args: jnp.ndarray,
):
H0_val = args[0]

if H0_val is not None:
H = H0_val # use H0 if given
else:
H = Ht(t) # type: ignore
# print("H", H.shape)
# print("psit", ψₜ.shape)
ψₜ_dot = -1j * (H @ ψₜ)

return ψₜ_dot

term = ODETerm(f)
solver = Dopri5()
saveat = SaveAt(ts=t_list)
stepsize_controller = PIDController(rtol=1e-6, atol=1e-6)

sol = diffeqsolve(
term,
solver,
t0=t_list[0],
t1=t_list[-1],
dt0=t_list[1] - t_list[0],
y0=ψ,
saveat=saveat,
stepsize_controller=stepsize_controller,
args=[H0],
)

return sol.ys


@partial(
jit,
static_argnums=(3,),
)
def sesolve_iso(
ψ: jnp.ndarray,
t_list: jnp.ndarray,
H0: Optional[jnp.ndarray] = None,
Ht: Optional[Callable[[float], jnp.ndarray]] = None,
):
"""Schrödinger Equation solver.
Args:
ψ: initial statevector
Expand All @@ -127,8 +273,8 @@ def sesolve(
Returns:
list of states
"""
ψ = complex_to_real_iso_vector(ψ + 0.0j)
H0 = None if H0 is None else complex_to_real_iso_matrix(H0 + 0.0j)
ψ = complex_to_real_iso_vector(jnp.asarray(ψ) + 0.0j)
H0 = None if H0 is None else complex_to_real_iso_matrix(jnp.asarray(H0) + 0.0j)

def f(
t: float,
Expand Down
2 changes: 1 addition & 1 deletion jaxquantum/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from jax import lax, jit
from jax import device_put
from jax.config import config
from jax import config
from jax._src.scipy.special import gammaln
import jax.numpy as jnp
import numpy as np
Expand Down
Loading

0 comments on commit 600513b

Please sign in to comment.