Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modularization fixes for fast_callable interpreters #35774

Merged
merged 11 commits into from
Jul 1, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 63 additions & 31 deletions src/sage/ext/fast_callable.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -465,41 +465,73 @@ def fast_callable(x, domain=None, vars=None,
etb = ExpressionTreeBuilder(vars=vars, domain=domain)
et = x._fast_callable_(etb)

if isinstance(domain, sage.rings.abc.RealField):
from sage.ext.interpreters.wrapper_rr import Wrapper_rr as builder
str = InstructionStream(sage.ext.interpreters.wrapper_rr.metadata,
len(vars),
domain)

elif isinstance(domain, sage.rings.abc.ComplexField):
from sage.ext.interpreters.wrapper_cc import Wrapper_cc as builder
str = InstructionStream(sage.ext.interpreters.wrapper_cc.metadata,
len(vars),
domain)

elif isinstance(domain, sage.rings.abc.RealDoubleField) or domain is float:
from sage.ext.interpreters.wrapper_rdf import Wrapper_rdf as builder
str = InstructionStream(sage.ext.interpreters.wrapper_rdf.metadata,
len(vars),
domain)
elif isinstance(domain, sage.rings.abc.ComplexDoubleField):
from sage.ext.interpreters.wrapper_cdf import Wrapper_cdf as builder
str = InstructionStream(sage.ext.interpreters.wrapper_cdf.metadata,
len(vars),
domain)
elif domain is None:
from sage.ext.interpreters.wrapper_py import Wrapper_py as builder
str = InstructionStream(sage.ext.interpreters.wrapper_py.metadata,
len(vars))
else:
from sage.ext.interpreters.wrapper_el import Wrapper_el as builder
str = InstructionStream(sage.ext.interpreters.wrapper_el.metadata,
len(vars),
domain)
builder, str = _builder_and_stream(vars=vars, domain=domain)

generate_code(et, str)
str.instr('return')
return builder(str.get_current())


def _builder_and_stream(vars, domain):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see what's going on from the diff, but if someone comes along a year from now, this docstring isn't really going to explain what the function is doing and why. Can you expand it a little?

Also, if possible, can you test the import-failure code path?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just mechanical refactoring of code that was nested a bit too deep. Let me see if I remember what it is doing, after a mere 4 days, not a year...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you test the import-failure code path?

done in afd3be6

r"""
Return a builder and a stream.

INPUT:

- ``vars`` -- a sequence of variable names

- ``domain`` -- a Sage parent or Python type or ``None``; if non-``None``,
all arithmetic is done in that domain

OUTPUT: A :class:`Wrapper`, an class:`InstructionStream`

EXAMPLES:

sage: from sage.ext.fast_callable import _builder_and_stream
sage: _builder_and_stream(["x", "y"], ZZ)
(<class 'sage.ext.interpreters.wrapper_el.Wrapper_el'>,
<sage.ext.fast_callable.InstructionStream object at 0x...>)
"""
if isinstance(domain, sage.rings.abc.RealField):
try:
from sage.ext.interpreters.wrapper_rr import metadata, Wrapper_rr as builder
except ImportError:
pass
else:
return builder, InstructionStream(metadata, len(vars), domain)

if isinstance(domain, sage.rings.abc.ComplexField):
try:
from sage.ext.interpreters.wrapper_cc import metadata, Wrapper_cc as builder
except ImportError:
pass
else:
return builder, InstructionStream(metadata, len(vars), domain)

if isinstance(domain, sage.rings.abc.RealDoubleField) or domain is float:
try:
from sage.ext.interpreters.wrapper_rdf import metadata, Wrapper_rdf as builder
except ImportError:
pass
else:
return builder, InstructionStream(metadata, len(vars), domain)

if isinstance(domain, sage.rings.abc.ComplexDoubleField):
try:
from sage.ext.interpreters.wrapper_cdf import metadata, Wrapper_cdf as builder
except ImportError:
pass
else:
return builder, InstructionStream(metadata, len(vars), domain)

if domain is None:
from sage.ext.interpreters.wrapper_py import metadata, Wrapper_py as builder
return builder, InstructionStream(metadata, len(vars))

from sage.ext.interpreters.wrapper_el import metadata, Wrapper_el as builder
return builder, InstructionStream(metadata, len(vars), domain)


def function_name(fn):
r"""
Given a function, return a string giving a name for the function.
Expand Down
36 changes: 18 additions & 18 deletions src/sage_setup/autogen/interpreters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,10 @@
from .instructions import *
from .memory import *
from .specs.base import *
from .specs.cdf import *
from .specs.element import *
from .specs.python import *
from .specs.rdf import *
from .specs.rr import *
from .specs.cc import *
from .storage import *
from .utils import *


# Gather up a list of all interpreter classes imported into this module
# A better way might be to recursively iterate InterpreterSpec.__subclasses__
# or to use a registry, but this is fine for now.
_INTERPRETERS = sorted(filter(lambda c: (isinstance(c, type) and
issubclass(c, InterpreterSpec) and
c.name),
globals().values()),
key=lambda c: c.name)

# Tuple of (filename_root, extension, method) where filename_root is the
# root of the filename to be joined with "_<interpreter_name>".ext and
# method is the name of a get_ method on InterpreterGenerator that returns
Expand All @@ -157,6 +142,7 @@ def build_interp(interp_spec, dir):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: testdir = tmp_dir()
sage: rdf_interp = RDFInterpreter()
sage: build_interp(rdf_interp, testdir)
Expand All @@ -174,7 +160,7 @@ def build_interp(interp_spec, dir):
write_if_changed(path, method())


def rebuild(dirname, force=False):
def rebuild(dirname, force=False, interpreters=None, distribution=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're adding arguments, can you add the missing INPUT block that explains what they do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done in 12c04b8

r"""
Check whether the interpreter and wrapper sources have been written
since the last time this module was changed. If not, write them.
Expand All @@ -194,6 +180,20 @@ def rebuild(dirname, force=False):
# you run it).
print("Building interpreters for fast_callable")

if interpreters is None:
interpreters = ['CDF', 'Element', 'Python', 'RDF', 'RR', 'CC']

from importlib import import_module

_INTERPRETERS = [getattr(import_module('sage_setup.autogen.interpreters.specs.' + interpreter.lower()),
interpreter + 'Interpreter')
for interpreter in interpreters]

if distribution is None:
all_py = 'all.py'
else:
all_py = f'all__{distribution.replace("-", "_")}.py'

try:
os.makedirs(dirname)
except OSError:
Expand All @@ -213,7 +213,7 @@ class NeedToRebuild(Exception):
try:
if force:
raise NeedToRebuild("-> Force rebuilding interpreters")
gen_file = os.path.join(dirname, 'all.py')
gen_file = os.path.join(dirname, all_py)
if not os.path.isfile(gen_file):
raise NeedToRebuild("-> First build of interpreters")

Expand All @@ -235,5 +235,5 @@ class NeedToRebuild(Exception):
for interp in _INTERPRETERS:
build_interp(interp(), dirname)

with open(os.path.join(dirname, 'all.py'), 'w') as f:
with open(os.path.join(dirname, all_py), 'w') as f:
f.write("# " + AUTOGEN_WARN)
15 changes: 15 additions & 0 deletions src/sage_setup/autogen/interpreters/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, spec):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: interp = RDFInterpreter()
sage: gen = InterpreterGenerator(interp)
sage: gen._spec is interp
Expand Down Expand Up @@ -72,6 +73,7 @@ def gen_code(self, instr_desc, write):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: interp = RDFInterpreter()
sage: gen = InterpreterGenerator(interp)
sage: from io import StringIO
Expand Down Expand Up @@ -218,6 +220,7 @@ def func_header(self, cython=False):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.element import ElementInterpreter
sage: interp = ElementInterpreter()
sage: gen = InterpreterGenerator(interp)
sage: print(gen.func_header())
Expand Down Expand Up @@ -260,6 +263,7 @@ def write_interpreter(self, write):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: interp = RDFInterpreter()
sage: gen = InterpreterGenerator(interp)
sage: from io import StringIO
Expand Down Expand Up @@ -307,6 +311,7 @@ def write_wrapper(self, write):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: interp = RDFInterpreter()
sage: gen = InterpreterGenerator(interp)
sage: from io import StringIO
Expand Down Expand Up @@ -476,6 +481,7 @@ def write_pxd(self, write):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: interp = RDFInterpreter()
sage: gen = InterpreterGenerator(interp)
sage: from io import StringIO
Expand Down Expand Up @@ -527,6 +533,9 @@ def get_interpreter(self):
First we get the InterpreterSpec for several interpreters::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: from sage_setup.autogen.interpreters.specs.rr import RRInterpreter
sage: from sage_setup.autogen.interpreters.specs.element import ElementInterpreter
sage: rdf_spec = RDFInterpreter()
sage: rr_spec = RRInterpreter()
sage: el_spec = ElementInterpreter()
Expand Down Expand Up @@ -649,6 +658,9 @@ def get_wrapper(self):
First we get the InterpreterSpec for several interpreters::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: from sage_setup.autogen.interpreters.specs.rr import RRInterpreter
sage: from sage_setup.autogen.interpreters.specs.element import ElementInterpreter
sage: rdf_spec = RDFInterpreter()
sage: rr_spec = RRInterpreter()
sage: el_spec = ElementInterpreter()
Expand Down Expand Up @@ -972,6 +984,9 @@ def get_pxd(self):
First we get the InterpreterSpec for several interpreters::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: from sage_setup.autogen.interpreters.specs.rr import RRInterpreter
sage: from sage_setup.autogen.interpreters.specs.element import ElementInterpreter
sage: rdf_spec = RDFInterpreter()
sage: rr_spec = RRInterpreter()
sage: el_spec = ElementInterpreter()
Expand Down
11 changes: 10 additions & 1 deletion src/sage_setup/autogen/interpreters/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class InstrSpec(object):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: pg = RDFInterpreter().pg
sage: InstrSpec('add', pg('SS','S'), code='o0 = i0+i1;')
add: SS->S = 'o0 = i0+i1;'
Expand All @@ -213,7 +214,7 @@ def __init__(self, name, io, code=None, uses_error_handler=False,
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *

sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: pg = RDFInterpreter().pg
sage: InstrSpec('add', pg('SS','S'), code='o0 = i0+i1;')
add: SS->S = 'o0 = i0+i1;'
Expand Down Expand Up @@ -288,6 +289,7 @@ def __repr__(self):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: pg = RDFInterpreter().pg
sage: InstrSpec('add', pg('SS','S'), code='o0 = i0+i1;')
add: SS->S = 'o0 = i0+i1;'
Expand All @@ -310,6 +312,7 @@ def instr_infix(name, io, op):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: pg = RDFInterpreter().pg
sage: instr_infix('mul', pg('SS', 'S'), '*')
mul: SS->S = 'o0 = i0 * i1;'
Expand All @@ -325,6 +328,7 @@ def instr_funcall_2args(name, io, op):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: pg = RDFInterpreter().pg
sage: instr_funcall_2args('atan2', pg('SS', 'S'), 'atan2')
atan2: SS->S = 'o0 = atan2(i0, i1);'
Expand All @@ -340,6 +344,7 @@ def instr_unary(name, io, op):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: pg = RDFInterpreter().pg
sage: instr_unary('sin', pg('S','S'), 'sin(i0)')
sin: S->S = 'o0 = sin(i0);'
Expand All @@ -357,6 +362,7 @@ def instr_funcall_2args_mpfr(name, io, op):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rr import RRInterpreter
sage: pg = RRInterpreter().pg
sage: instr_funcall_2args_mpfr('add', pg('SS','S'), 'mpfr_add')
add: SS->S = 'mpfr_add(o0, i0, i1, MPFR_RNDN);'
Expand All @@ -372,6 +378,7 @@ def instr_funcall_1arg_mpfr(name, io, op):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rr import RRInterpreter
sage: pg = RRInterpreter().pg
sage: instr_funcall_1arg_mpfr('exp', pg('S','S'), 'mpfr_exp')
exp: S->S = 'mpfr_exp(o0, i0, MPFR_RNDN);'
Expand All @@ -386,6 +393,7 @@ def instr_funcall_2args_mpc(name, io, op):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.cc import CCInterpreter
sage: pg = CCInterpreter().pg
sage: instr_funcall_2args_mpc('add', pg('SS','S'), 'mpc_add')
add: SS->S = 'mpc_add(o0, i0, i1, MPC_RNDNN);'
Expand All @@ -400,6 +408,7 @@ def instr_funcall_1arg_mpc(name, io, op):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.cc import CCInterpreter
sage: pg = CCInterpreter().pg
sage: instr_funcall_1arg_mpc('exp', pg('S','S'), 'mpc_exp')
exp: S->S = 'mpc_exp(o0, i0, MPC_RNDNN);'
Expand Down
1 change: 1 addition & 0 deletions src/sage_setup/autogen/interpreters/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class using this memory chunk, to allocate local variables.
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rr import *
sage: mc = MemoryChunkRRRetval('retval', ty_mpfr)
sage: mc.declare_call_locals()
' cdef RealNumber retval = (self.domain)()\n'
Expand Down
6 changes: 6 additions & 0 deletions src/sage_setup/autogen/interpreters/specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(self):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: from sage_setup.autogen.interpreters.specs.rr import RRInterpreter
sage: interp = RDFInterpreter()
sage: interp.c_header
'#include <gsl/gsl_math.h>'
Expand Down Expand Up @@ -84,6 +86,7 @@ def _set_opcodes(self):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: interp = RDFInterpreter()
sage: interp.instr_descs[5].opcode
5
Expand Down Expand Up @@ -128,6 +131,9 @@ def __init__(self, type, mc_retval=None):
EXAMPLES::

sage: from sage_setup.autogen.interpreters import *
sage: from sage_setup.autogen.interpreters.specs.rdf import RDFInterpreter
sage: from sage_setup.autogen.interpreters.specs.rr import RRInterpreter
sage: from sage_setup.autogen.interpreters.specs.element import ElementInterpreter
sage: rdf = RDFInterpreter()
sage: rr = RRInterpreter()
sage: el = ElementInterpreter()
Expand Down
Loading