Skip to content

Commit

Permalink
working static types for all files but geometry.py, index.py
Browse files Browse the repository at this point in the history
  • Loading branch information
GNiendorf committed Sep 7, 2024
1 parent d842650 commit 0f46129
Show file tree
Hide file tree
Showing 13 changed files with 242 additions and 139 deletions.
2 changes: 1 addition & 1 deletion tests/test_hyperbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
def test_rms_hyperbolic():
geo = [back_lens, lens, stop]
ray_group = tp.ray_plane(geo, [0., 0., 0.], 1.1, d=[0.,0.,1.], nrays=100)
rms = tp.spotdiagram(geo, ray_group, optimizer=True)
rms = tp.spot_rms(geo, ray_group)
assert rms == 0.

2 changes: 1 addition & 1 deletion tests/test_parabolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
def test_rms_parabolic():
geo = [mirror, stop]
ray_group = tp.ray_plane(geo, [0., 0., -1.5], 1.1, d=[0.,0.,1.], nrays=100)
rms = tp.spotdiagram(geo, ray_group, optimizer=True)
rms = tp.spot_rms(geo, ray_group)
assert rms == 0.
9 changes: 4 additions & 5 deletions tracepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
from .geometry import geometry
from .optimize import optimize
from .geoplot import plotxz, plotyz, plot2d
from .optplot import spotdiagram, plotobject, rayaberration
from .iotables import *
from .transforms import *
from .raygroup import *
from .index import *
from .optplot import spotdiagram, plotobject, rayaberration, spot_rms
from .iotables import save_optics
from .raygroup import ray_plane
from .index import cauchy_two_term, glass_index
from .utils import gen_rot
9 changes: 9 additions & 0 deletions tracepy/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Constants used by optimizer
SURV_CONST = 100 # Weight of failed propagation.
MAX_RMS = 999 # Maximum RMS penalty for trace error.

# Constants used for ray objects
MAX_INTERSECTION_ITERATIONS = 1e4 # Max iter before failed intersection search.
MAX_REFRACTION_ITERATIONS = 1e5 # Max iter before failed refraction.
INTERSECTION_CONVERGENCE_TOLERANCE = 1e-6 # Tolerance for intersection search.
REFRACTION_CONVERGENCE_TOLERANCE = 1e-15 # Tolerance for refraction.
24 changes: 9 additions & 15 deletions tracepy/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from .exceptions import NotOnSurfaceError
from .index import glass_index

from typing import Dict, List, Tuple

class geometry:
"""Class for the different surfaces in an optical system.
Expand Down Expand Up @@ -39,12 +41,12 @@ class geometry:
If c is 0 then the surface is planar.
name (optional): str
Name of the surface, used for optimization
R (generated): np.matrix((3,3))
R (generated): np.array((3,3))
Rotation matrix for the surface from rotation angles D.
"""

def __init__(self, params):
def __init__(self, params: Dict):
self.P = params['P']
self.D = np.array(params.get('D', [0., 0., 0.]))
self.action = params['action']
Expand All @@ -59,15 +61,7 @@ def __init__(self, params):
self.glass = glass_index(params.get('glass'))
self.check_params()

def __getitem__(self, item):
""" Return attribute of geometry. """
return getattr(self, item)

def __setitem__(self, item, value):
""" Set attribute of geometry. """
return setattr(self, item, value)

def check_params(self):
def check_params(self) -> None:
"""Check that required parameters are given and update needed parameters.
Summary
Expand Down Expand Up @@ -98,15 +92,15 @@ def check_params(self):
#Used for planes, does not affect calculations.
self.kappa = 1.

def get_surface(self, point):
def get_surface(self, point: np.ndarray) -> Tuple[float, List[float]]:
""" Returns the function and derivitive of a surface for a point. """
return self.conics(point)

def get_surface_plot(self, points):
def get_surface_plot(self, points: np.ndarray) -> np.ndarray:
""" Returns the function value for an array of points. """
return self.conics_plot(points)

def conics(self, point):
def conics(self, point: np.ndarray) -> Tuple[float, List[float]]:
"""Returns function value and derivitive list for conics and sphere surfaces.
Note
Expand Down Expand Up @@ -144,7 +138,7 @@ def conics(self, point):
derivitive = [-X*E, -Y*E, 1.]
return function, derivitive

def conics_plot(self, point):
def conics_plot(self, point: np.ndarray) -> np.ndarray:
"""Returns Z values for an array of points for plotting conics.
Parameters
Expand Down
50 changes: 31 additions & 19 deletions tracepy/geoplot.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import matplotlib.pyplot as plt
import numpy as np
from numpy import pi
import matplotlib.pyplot as plt

from .ray import ray
from .geometry import geometry
from .transforms import lab_frame
from .transforms import lab_frame_points

def _gen_points(surfaces):
from typing import List, Dict, Optional

def _gen_points(surfaces: List[geometry]) -> np.ndarray:
"""Generates the mesh points for each surface in the obj frame.
Parameters
----------
surface : geometry object
Surface whos reference frame to generate the points in.
surfaces : list of geometry objects
List of surfaces whos reference frames to generate the points in.
Returns
-------
Expand All @@ -37,12 +39,14 @@ def _gen_points(surfaces):
surfpoints.append(meshpoints)
return np.array(surfpoints)

def _plot_rays(rays, axes, pltparams):
def _plot_rays(rays: np.ndarray,
axes: List[int],
pltparams: Dict) -> None:
"""Plots 2d ray history points. Takes list axes to specify axes (0, 1, 2) to plot.
Parameters
----------
rays : list of ray objects
rays : np.array of ray objects
Rays that are going to be plotted.
axes : list of length 2 with integers from range [0,2]
Axes (X, Y, Z) to plot from ray points.
Expand All @@ -61,7 +65,7 @@ def _plot_rays(rays, axes, pltparams):
#Plot direction of ray after stop.
plt.plot([G_p, G_p+I_p],[F_p, F_p+H_p], **pltparams)

def _clip_lens(surfaces, surfpoints, idx):
def _clip_lens(surfaces: List[geometry], surfpoints: np.ndarray, idx: int) -> np.ndarray:
"""Clips points ouside of a lens intersection point.
Parameters
Expand Down Expand Up @@ -89,7 +93,7 @@ def _clip_lens(surfaces, surfpoints, idx):
surfpoints[idx+1][:,2][clipped_idx] = np.nan
return surfpoints

def _plot_surfaces(geo_params, axes):
def _plot_surfaces(geo_params: List[Dict], axes: List[int]) -> None:
"""Plots 2d surface cross sections. Takes list axes to specify axes (0, 1, 2) to plot.
Note
Expand All @@ -115,15 +119,15 @@ def _plot_surfaces(geo_params, axes):
if lens_condition:
surfpoints = _clip_lens(surfaces, surfpoints, idx)
with np.errstate(invalid='ignore'):
if np.any(np.mod(surf.D/pi, 1) != 0) and surf.c == 0 and surf.diam == 0:
if np.any(np.mod(surf.D/np.pi, 1) != 0) and surf.c == 0 and surf.diam == 0:
#Find cross section points.
cross_idx = abs(surfpoints[idx][:,axes[1]]) == 0
else:
#Find cross section points.
cross_idx = abs(surfpoints[idx][:,1-axes[0]]) == 0
cross_points = surfpoints[idx][cross_idx]
#Transform to lab frame.
points = lab_frame(surf.R, surf, cross_points)
points = lab_frame_points(surf.R, surf, cross_points)
F, G = points[:,axes[0]], points[:,axes[1]]
#Connect the surfaces in a lens
if surfaces[idx].action == surfaces[idx-1].action == 'refraction' and start is not None:
Expand All @@ -145,14 +149,17 @@ def _plot_surfaces(geo_params, axes):
end = np.array([F[-1], G[-1]])
plt.plot(G, F, 'k')

def plotxz(geo_params, rays, pltparams={'c': 'red', 'alpha': 0.3 }, both=None):
def plotxz(geo_params: List[Dict],
ray_list: List[ray],
pltparams: Dict = {'c': 'red', 'alpha': 0.3 },
both: Optional[bool] = None) -> None:
"""Plots the xz coordinates of all rays and surface cross sections.
Parameters
----------
geo_params : list of dictionaries
Surfaces in propagation order to plot.
rays : list of ray objects
ray_list : list of ray objects
Rays that are going to be plotted.
pltparams : dictionary
Plot characteristics of rays such as colors and alpha.
Expand All @@ -161,7 +168,7 @@ def plotxz(geo_params, rays, pltparams={'c': 'red', 'alpha': 0.3 }, both=None):
"""

rays = np.array([rayiter for rayiter in rays if rayiter.active != 0])
rays = np.array([rayiter for rayiter in ray_list if rayiter.active])
#Override 1,1,1 subplot if displaying side-by-side.
if both is None:
#Keep aspect ratio equal.
Expand All @@ -171,14 +178,17 @@ def plotxz(geo_params, rays, pltparams={'c': 'red', 'alpha': 0.3 }, both=None):
plt.xlabel("Z")
plt.ylabel("X")

def plotyz(geo_params, rays, pltparams={'c': 'red', 'alpha': 0.3 }, both=None):
def plotyz(geo_params: List[Dict],
ray_list: List[ray],
pltparams: Dict = {'c': 'red', 'alpha': 0.3 },
both: Optional[bool] = None) -> None:
"""Plots the yz coordinates of all rays and surface cross sections.
Parameters
----------
geo_params : list of dictionaries
Surfaces in propagation order to plot.
rays : list of ray objects
ray_list : list of ray objects
Rays that are going to be plotted.
pltparams : dictionary
Plot characteristics of rays such as colors and alpha.
Expand All @@ -187,7 +197,7 @@ def plotyz(geo_params, rays, pltparams={'c': 'red', 'alpha': 0.3 }, both=None):
"""

rays = np.array([rayiter for rayiter in rays if rayiter.active != 0])
rays = np.array([rayiter for rayiter in ray_list if rayiter.active])
#Override 1,1,1 subplot if displaying side-by-side.
if both is None:
#Keep aspect ratio equal.
Expand All @@ -197,7 +207,9 @@ def plotyz(geo_params, rays, pltparams={'c': 'red', 'alpha': 0.3 }, both=None):
plt.xlabel("Z")
plt.ylabel("Y")

def plot2d(geo_params, rays, pltparams={'c': 'red', 'alpha': 0.3 }):
def plot2d(geo_params: List[Dict],
rays: List[ray],
pltparams: Dict = {'c': 'red', 'alpha': 0.3 }) -> None:
"""Plots both xz and yz side-by-side.
Parameters
Expand Down
26 changes: 20 additions & 6 deletions tracepy/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,29 @@
import numpy as np
from scipy.optimize import curve_fit

def cauchy_two_term(x,B,C):
'''
This is a simple two term cauchy for index of refraction as function of wavelength.
def cauchy_two_term(x: float, B: float, C: float) -> float:
"""
This is a simple two-term Cauchy equation for the index of refraction as a function of wavelength.
https://en.wikipedia.org/wiki/Cauchy%27s_equation
It is used to create a function for refractive index database entries when only tabulated data is available.
'''
return B + (C/(x**2))
Parameters
----------
x : float
Wavelength (in micrometers).
B : float
First term of the Cauchy equation.
C : float
Second term of the Cauchy equation.
Returns
-------
float
Refractive index at the given wavelength.
"""
return B + (C / (x ** 2))

def glass_index (glass):
def glass_index(glass):
'''
Given a glass name this function will return a function that takes wavelength in microns and returns index of refraction.
- All values were taken from refractiveindexinfo's database of different optical glasses.
Expand Down
4 changes: 3 additions & 1 deletion tracepy/iotables.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pandas as pd

def save_optics(geo_params, filename):
# TODO: Rewrite save_optics to iterate over geo_params to handle edge
# cases, and convert to dataframe at end before exporting to csv.
def save_optics(geo_params, filename: str):
"""Save geometry to an optics table in csv format.
Note
Expand Down
28 changes: 18 additions & 10 deletions tracepy/optimize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import numpy as np
from scipy.optimize import least_squares

from .optplot import spotdiagram
from .optplot import spot_rms
from .raygroup import ray_plane
from .constants import SURV_CONST, MAX_RMS
from .exceptions import TraceError

def update_geometry(inputs, geoparams, vary_dicts):
from typing import List, Union, Dict, Optional

def update_geometry(inputs: List[Union[float, int]],
geoparams: List[Dict],
vary_dicts: List[Dict]) -> List[Dict]:
"""Return the geometry requested by the optimization algorithm.
Parameters
Expand Down Expand Up @@ -35,7 +40,9 @@ def update_geometry(inputs, geoparams, vary_dicts):
vary_idxs += 1
return geoparams

def get_rms(inputs, geoparams, vary_dicts):
def get_rms(inputs: List[Union[float, int]],
geoparams: List[Dict],
vary_dicts: List[Dict]) -> float:
"""Return the rms of an updated geometry.
Note
Expand All @@ -62,16 +69,17 @@ def get_rms(inputs, geoparams, vary_dicts):

params_iter = update_geometry(inputs, geoparams, vary_dicts)
raygroup_iter = ray_plane(params_iter, [0., 0., 0.], 1.1, d=[0.,0.,1.], nrays=50)
ratio_surv = np.sum([1 for ray in raygroup_iter if ray.active != 0])/len(raygroup_iter)
ratio_surv = np.sum([1 for ray in raygroup_iter if ray.active])/len(raygroup_iter)
try:
rms = spotdiagram(params_iter, raygroup_iter, optimizer=True)
rms = spot_rms(params_iter, raygroup_iter)
except TraceError:
rms = 999.
#Weight of failed propagation.
surv_const = 100
return rms + (1-ratio_surv)*surv_const
rms = MAX_RMS
return rms + (1 - ratio_surv) * SURV_CONST

def optimize(geoparams, vary_dicts, typeof='least_squares', max_iter=None):
def optimize(geoparams:List[Dict],
vary_dicts: List[Dict],
typeof: str = 'least_squares',
max_iter: Optional[int] = None) -> List[Dict]:
"""Optimize a given geometry for a given varylist and return the new geometry.
Parameters
Expand Down
Loading

0 comments on commit 0f46129

Please sign in to comment.