Skip to content

Commit

Permalink
Fix major logical error in CSE printing
Browse files Browse the repository at this point in the history
- use-counting was not applied recursively to the arguments of special functions; instead, the cpp_cse implementation simply fell back to printing these functions using GiNaC’s own C++ output formatting

- this means that any temporary CSE variables appearing in the argument of a special function might not be correctly use-counted, and therefore might not be output to the temporary pool. (This first showed up in Yvette’s model, which has a cosh with a fairly large argument for which CSE is nontrivial).

- now fixed by recursively applying use-counting, and also mapping GiNaC function names to their correct C++ versions. We also check for special functions that have not yet been implemented.
  • Loading branch information
ds283 committed Jul 7, 2017
1 parent 35c9e4d commit f7d8a28
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,48 +90,51 @@ void cse::parse(const GiNaC::ex& expr, std::string name)

timing_instrument instrument(timer);

// find iterator to entire expression;
// used to determine whether to create an anonymous temporary name, or used the supplied string "name"
// was a name supplied?
// (eg. "name" might be __InternalHsq or __InternalEps, for when we are trying to simplify
// expressions used in client code
const bool use_name = !name.empty();

// if a name was supplied, find iterator to entire expression
// technically this is --expr.postorder_end(), however these iterators
// are forward-directed only so we have to do a search
GiNaC::const_postorder_iterator last;

const bool use_name = name.length() > 0;
if(use_name) // avoid performing possibly expensive search if no name supplied
{
for(GiNaC::const_postorder_iterator t = expr.postorder_begin(); t != expr.postorder_end(); ++t)
{
last = t;
}
for(auto t = expr.postorder_begin(); t != expr.postorder_end(); ++t) last = t;
}

for(GiNaC::const_postorder_iterator t = expr.postorder_begin(); t != expr.postorder_end(); ++t)
for(auto t = expr.postorder_begin(); t != expr.postorder_end(); ++t)
{
// print this expression without use counting (false means that print will use get_symbol_without_use_count)
std::string e = this->print(*t, false);

// does this expression already exist in the lookup table?
symbol_table::iterator u = this->symbols.find(e);
auto u = this->symbols.find(e);

// if not, we should insert it
if(u == this->symbols.end())
if(u != this->symbols.end()) continue;

// get a name for this symbol. If it is the top-level expression and a name was supplied then we use it,
// otherwise we make a temporary
std::string nm;
if(use_name && t == last) nm = name;
else nm = this->make_symbol();

// perform insertion
auto result = this->symbols.emplace(e, cse_impl::symbol_record{e, nm});

// check whether insertion took place; failure could be due to aliasing
if(!result.second) throw cse_exception(name);

// if a name was supplied, we automatically deposit everything to the pool, because typically clients
// further up the stack will only get a GiNaC symbol corresponding to this name;
// they won't have an explicit expression to print which could cause these temporaries to be deposited
cse_impl::symbol_record& record = result.first->second;
if(use_name && !record.is_written())
{
std::string nm;
if(use_name && t == last) nm = name;
else nm = this->make_symbol();

// perform insertion
std::pair< symbol_table::iterator, bool > result = this->symbols.emplace(std::make_pair( e, cse_impl::symbol_record(e, nm) ));

// check whether insertion took place; failure could be due to aliasing
if(!result.second) throw cse_exception(name);

// if a name was supplied, we automatically deposit everything to the pool, because typically clients
// further up the stack will only get a GiNaC symbol corresponding to this name;
// they won't have an explicit expression to print which could cause these temporaries to be deposited
if(use_name && !result.first->second.is_written())
{
this->decls.push_back(std::make_pair(result.first->second.get_symbol(), result.first->second.get_target()));
result.first->second.set_written();
}
this->decls.emplace_back(record.get_symbol(), record.get_target());
record.set_written();
}
}
}
Expand All @@ -140,17 +143,19 @@ void cse::parse(const GiNaC::ex& expr, std::string name)
std::unique_ptr< std::list<std::string> >
cse::temporaries(const std::string& left, const std::string& mid, const std::string& right) const
{
std::unique_ptr< std::list<std::string> > rval = std::make_unique< std::list<std::string> >();
auto rval = std::make_unique< std::list<std::string> >();

// deposit each declaration into the output stream
for(const std::pair<std::string, std::string>& decl: this->decls)
for(const auto& decl: this->decls)
{
std::ostringstream out;

// replace LHS and RHS macros in the template
out << left << decl.first << mid << decl.second << right << '\n';

rval->push_back(out.str());
std::string temp = left;
temp.append(decl.first);
temp.append(mid);
temp.append(decl.second);
temp.append(right);
temp.append("\n");

rval->push_back(temp);
}

return(rval);
Expand All @@ -166,7 +171,7 @@ std::string cse::get_symbol_without_use_count(const GiNaC::ex& expr)
std::string e = this->print(expr, false);

// search for this expression in the lookup table
symbol_table::iterator t = this->symbols.find(e);
auto t = this->symbols.find(e);

// was it present? if not, return the plain expression
if(t == this->symbols.end()) return e;
Expand All @@ -185,11 +190,11 @@ std::string cse::get_symbol_with_use_count(const GiNaC::ex& expr)
std::string e = this->print(expr, true);

// search for this expression in the lookup table
symbol_table::iterator t = this->symbols.find(e);
auto t = this->symbols.find(e);

// was it present? if not, return the plain expression
if(t == this->symbols.end()) return e;

// if it was present, check whether this symbol has been written into the list
// of declarations

Expand All @@ -205,10 +210,14 @@ std::string cse::get_symbol_with_use_count(const GiNaC::ex& expr)

std::string cse::make_symbol()
{
std::ostringstream s;
std::string s{this->temporary_name_kernel};

s.append("_");
s.append(std::to_string(serial_number));
s.append("_");
s.append(std::to_string(symbol_counter));

s << this->temporary_name_kernel << "_" << serial_number << "_" << symbol_counter;
symbol_counter++;
++symbol_counter;

return(s.str());
return s;
}
82 changes: 61 additions & 21 deletions CppTransport/translator/backends/languages/cpp/cpp_cse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,42 +35,82 @@

namespace cpp
{

const std::map< std::string, std::string > func_convert
{
{"abs", "std::abs"},
{"sqrt", "std::sqrt"},
{"sin", "std::sin"},
{"cos", "std::cos"},
{"tan", "std::tan"},
{"asin", "std::asin"},
{"acos", "std::acos"},
{"atan", "std::atan"},
{"atan2", "std::atan2"},
{"sinh", "std::sinh"},
{"cosh", "std::cosh"},
{"tanh", "std::tanh"},
{"asinh", "std::asinh"},
{"acosh", "std::acosh"},
{"atanh", "std::atanh"},
{"exp", "std:exp"},
{"log", "std::log"},
{"pow", "std::pow"},
{"tgamma", "std::tgamma"},
{"lgamma", "std::lgamma"}
};


std::string cpp_cse::print(const GiNaC::ex& expr, bool use_count)
{
std::ostringstream out;
std::string name;

std::string name;

if(GiNaC::is_a<GiNaC::function>(expr)) name = GiNaC::ex_to<GiNaC::function>(expr).get_name();
else name = GiNaC::ex_to<GiNaC::basic>(expr).class_name();
else name = GiNaC::ex_to<GiNaC::basic>(expr).class_name();

if (name == "numeric") return this->printer.ginac(expr);
else if(name == "symbol") return this->printer.ginac(expr);
else if(name == "add") return this->print_operands(expr, "+", use_count);
else if(name == "mul") return this->print_operands(expr, "*", use_count);
else if(name == "power") return this->print_power(expr, use_count);

// not a standard operation, so assume it must be a special function
// look up it's C++ form in func_map, and then format its arguments,
// taking care to keep track of use counts

auto t = func_convert.find(name);
if(t == func_convert.end())
{
std::ostringstream msg;
msg << ERROR_UNIMPLEMENTED_MATHS_FUNCTION << " '" << name << "'";
throw cse_exception(msg.str());
}

if (name == "numeric") out << this->printer.ginac(expr);
else if(name == "symbol") out << this->printer.ginac(expr);
else if(name == "add") out << this->print_operands(expr, "+", use_count);
else if(name == "mul") out << this->print_operands(expr, "*", use_count);
else if(name == "power") out << this->print_power(expr, use_count);
else out << this->printer.ginac(expr);
std::string rval{t->second};
rval.append("(");
rval.append(this->print_operands(expr, ",", use_count));
rval.append(")");

return(out.str());
return rval;
}


std::string cpp_cse::print_operands(const GiNaC::ex& expr, std::string op, bool use_count)
{
std::ostringstream out;
std::string rval;

unsigned int c = 0;
for(GiNaC::const_iterator t = expr.begin(); t != expr.end(); ++t)
for(auto t = expr.begin(); t != expr.end(); ++t)
{
if(c > 0) out << op;

if(use_count) out << this->get_symbol_with_use_count(*t);
else out << this->get_symbol_without_use_count(*t);
if(c > 0) rval.append(op);
if(use_count) rval.append(this->get_symbol_with_use_count(*t));
else rval.append(this->get_symbol_without_use_count(*t));

++c;
}

return(out.str());
return rval;
}


Expand All @@ -79,7 +119,7 @@ namespace cpp
std::string cpp_cse::print_power(const GiNaC::ex& expr, bool use_count)
{
std::ostringstream out;
size_t n = expr.nops();
size_t n = expr.nops();

if(n != 2)
{
Expand All @@ -93,7 +133,7 @@ namespace cpp

if(GiNaC::is_a<GiNaC::numeric>(exp_generic))
{
const GiNaC::numeric& exp_numeric = GiNaC::ex_to<GiNaC::numeric>(exp_generic);
const auto& exp_numeric = GiNaC::ex_to<GiNaC::numeric>(exp_generic);

std::string sym;
if(use_count) sym = this->get_symbol_with_use_count(expr.op(0));
Expand Down
12 changes: 6 additions & 6 deletions CppTransport/translator/backends/languages/cpp/cpp_cse.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ namespace cpp
// INTERNAL API

protected:

//! print a GiNaC expression; if use_count is set then any temporaries which are
//! used will be marked for deposition
virtual std::string print (const GiNaC::ex& expr, bool use_count) override;

std::string print(const GiNaC::ex& expr, bool use_count) override;
//! print the operands to a GiNaC expression; if use_count is set then any temporaries
//! which are used will be marked for deposition
virtual std::string print_operands(const GiNaC::ex& expr, std::string op, bool use_count) override;

std::string print_operands(const GiNaC::ex& expr, std::string op, bool use_count) override;
//! special implementation of print_operands() to print a power;
//! uses strength reduction to compute integer powers by multiplication for small enough exponents
std::string print_power (const GiNaC::ex& expr, bool use_count);
std::string print_power(const GiNaC::ex& expr, bool use_count);

};

Expand Down
2 changes: 2 additions & 0 deletions CppTransport/translator/msg_en.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ constexpr auto ERROR_SPECIES_MAPPING_OVERFLOW = "Internal error: abstract
constexpr auto ERROR_SPECIES_TO_SPECIES_FAIL = "Internal error: species-to-species map must be applied to a phase-space index";
constexpr auto ERROR_MOMENTUM_TO_SPECIES_FAIL = "Internal error: momentum-to-species map must be applied to a phase-space index";

constexpr auto ERROR_UNIMPLEMENTED_MATHS_FUNCTION = "Unimplemented mathematical function";

constexpr auto MESSAGE_HOUR_LABEL = "h";
constexpr auto MESSAGE_MINUTE_LABEL = "m";
constexpr auto MESSAGE_SECOND_LABEL = "s";
Expand Down

0 comments on commit f7d8a28

Please sign in to comment.