Skip to content

Commit

Permalink
MessageSpatial2D::In::wrap()
Browse files Browse the repository at this point in the history
Needs better testing, will post model I was playing with in PR comments.
  • Loading branch information
Robadob committed Feb 9, 2022
1 parent 949aba8 commit 0c5a64a
Showing 1 changed file with 322 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,249 @@ class MessageSpatial2D::In {
detail::curve::Curve::NamespaceHash combined_hash;
};
/**
* Constructer
* This class is created when a search origin is provided to MessageSpatial2D::In::wrap()(float, float)
* It provides iterator access to a subset of the full message list, according to the provided search origin
*
* @see MessageSpatial2D::In::wrap()(float, float)
*/
class WrapFilter {
public:
/**
* Provides access to a specific message
* Returned by the iterator
* @see In::Filter::iterator
*/
class Message {
/**
* Paired Filter class which created the iterator
*/
const WrapFilter& _parent;
/**
* Relative cell within the Moore neighbourhood
* relative_cell[0] corresponds to x offset
* relative_cell[1] corresponds to y offset
*/
int relative_cell[2] = { -2, 1 };
/**
* This is the index after the final message, relative to the full message list, in the current bin
*/
int cell_index_max = 0;
/**
* This is the index of the currently accessed message, relative to the full message list
*/
int cell_index = 0;
/**
* The pre-calculated boundary wrapping distance to the message
*/
float distance;
__device__ float wrapped_distance(const float &x, const float &y) const {
// distance from x, y
// To _parent.loc
// Wrapping over boundaries
// https://blog.demofox.org/2017/10/01/calculating-the-distance-between-points-in-wrap-around-toroidal-space/
// Note, this falls over if either location is outside of [0, environmentWidth]
// That could be fixed moving out of bounds locations in bounds (+width until positive, then remainder op)
float dx = abs(_parent.loc[0] - x);
float dy = abs(_parent.loc[1] - y);

if (dx > _parent.metadata->environmentWidth[0] / 2.0f)
dx = _parent.metadata->environmentWidth[0] - dx;

if (dy > _parent.metadata->environmentWidth[1] / 2.0f)
dy = _parent.metadata->environmentWidth[1] - dy;

return sqrt(dx * dx + dy * dy);
}

public:
/**
* Constructs a message and directly initialises all of it's member variables
* @note See member variable documentation for their purposes
*/
__device__ Message(const WrapFilter& parent, const int relative_cell_x, const int relative_cell_y, const int& _cell_index_max, const int& _cell_index)
: _parent(parent)
, cell_index_max(_cell_index_max)
, cell_index(_cell_index) {
relative_cell[0] = relative_cell_x;
relative_cell[1] = relative_cell_y;
}
/**
* False minimal constructor used by iterator::end()
*/
__device__ Message(const WrapFilter& parent)
: _parent(parent) { }
/**
* Equality operator
* Compares all internal member vars for equality
* @note Does not compare _parent
*/
__device__ bool operator==(const Message& rhs) const {
return this->relative_cell == rhs.relative_cell
&& this->cell_index_max == rhs.cell_index_max
&& this->cell_index == rhs.cell_index;
}
/**
* This should only be called to compare against end()
* It has been modified to check for end of iteration with minimal instructions
* Therefore it does not even perform the equality operation
* @note Use operator==() if proper equality is required
*/
__device__ bool operator!=(const Message& rhs) const {
// The incoming Message& is end(), so we don't care about that
// We only care that the host object has reached end
// When the x offset equals 2, it has exceeded the [-1, 1] range
return !(this->relative_cell[0] >= 2);
}
/**
* Updates the message to return variables from the next message in the message list
* @return Returns itself
*/
__device__ Message& operator++();
/**
* Utility function for deciding next strip to access
*/
__device__ void nextCell() {
if (relative_cell[1] >= 1) {
relative_cell[1] = -1;
relative_cell[0]++;
} else {
relative_cell[1]++;
}
}
/**
* Returns the value for the current message attached to the named variable
* @param variable_name Name of the variable
* @tparam T type of the variable
* @tparam N Length of variable name (this should be implicit if a string literal is passed to variable name)
* @return The specified variable, else 0x0 if an error occurs
*/
template<typename T, unsigned int N>
__device__ T getVariable(const char(&variable_name)[N]) const;
/**
* Returns the specified variable array element from the current message attached to the named variable
* @param variable_name name used for accessing the variable, this value should be a string literal e.g. "foobar"
* @param index Index of the element within the variable array to return
* @tparam T Type of the message variable being accessed
* @tparam N The length of the array variable, as set within the model description hierarchy
* @tparam M Length of variable_name, this should always be implicit if passing a string literal
* @throws exception::DeviceError If name is not a valid variable within the agent (flamegpu must be built with SEATBELTS enabled for device error checking)
* @throws exception::DeviceError If T is not the type of variable 'name' within the message (flamegpu must be built with SEATBELTS enabled for device error checking)
* @throws exception::DeviceError If index is out of bounds for the variable array specified by name (flamegpu must be built with SEATBELTS enabled for device error checking)
*/
template<typename T, MessageNone::size_type N, unsigned int M>
__device__ T getVariable(const char(&variable_name)[M], const unsigned int& index) const;
/**
* Returns the wrapped distance from the search origin to the current message
*/
__device__ float getDistance() const {
return distance;
}
};
/**
* Stock iterator for iterating MessageSpatial3D::In::Filter::Message objects
*/
class iterator {
/**
* The message returned to the user
*/
Message _message;

public:
/**
* Constructor
* This iterator is constructed by MessageSpatial2D::In::WrapFilter::begin()(float, float)
* @see MessageSpatial2D::In::wrap()(float, float)
*/
__device__ iterator(const WrapFilter& parent, const int& relative_cell_x, const int& relative_cell_y, const int& _cell_index_max, const int& _cell_index)
: _message(parent, relative_cell_x, relative_cell_y, _cell_index_max, _cell_index) {
// Increment to find first message
++_message;
}
/**
* False constructor
* Only used by WrapFilter::end(), creates a null object
*/
__device__ iterator(const WrapFilter& parent)
: _message(parent) { }
/**
* Moves to the next message
* (Prefix increment operator)
*/
__device__ iterator& operator++() { ++_message; return *this; }
/**
* Moves to the next message
* (Postfix increment operator, returns value prior to increment)
*/
__device__ iterator operator++(int) {
iterator temp = *this;
++* this;
return temp;
}
/**
* Equality operator
* Compares message
*/
__device__ bool operator==(const iterator& rhs) const { return _message == rhs._message; }
/**
* Inequality operator
* Compares message
*/
__device__ bool operator!=(const iterator& rhs) const { return _message != rhs._message; }
/**
* Dereferences the iterator to return the message object, for accessing variables
*/
__device__ Message& operator*() { return _message; }
/**
* Dereferences the iterator to return the message object, for accessing variables
*/
__device__ Message* operator->() { return &_message; }
};
/**
* Constructor, takes the search parameters requried
* @param _metadata Pointer to message list metadata
* @param combined_hash agentfn+message hash for accessing message data
* @param x Search origin x coord
* @param y Search origin y coord
*/
__device__ WrapFilter(const MetaData* _metadata, const detail::curve::Curve::NamespaceHash& combined_hash, const float& x, const float& y);
/**
* Returns an iterator to the start of the message list subset about the search origin
*/
inline __device__ iterator begin(void) const {
// Bin before initial bin, as the constructor calls increment operator
return iterator(*this, -2, 1, 1, 0);
}
/**
* Returns an iterator to the position beyond the end of the message list subset
* @note This iterator is the same for all message list subsets
*/
inline __device__ iterator end(void) const {
// Empty init, because this object is never used
// iterator equality doesn't actually check the end object
return iterator(*this);
}

private:
/**
* Search origin
*/
float loc[2];
/**
* Search origin's grid cell
*/
GridPos2D cell;
/**
* Pointer to message list metadata, e.g. environment bounds, search radius, PBM location
*/
const MetaData* metadata;
/**
* CURVE hash for accessing message data
* agent function hash + message hash
*/
detail::curve::Curve::NamespaceHash combined_hash;
};
/**
* Constructor
* Initialises member variables
* @param agentfn_hash Added to message_hash to produce combined_hash
* @param message_hash Added to agentfn_hash to produce combined_hash
Expand All @@ -238,10 +480,24 @@ class MessageSpatial2D::In {
*
* @param x Search origin x coord
* @param y Search origin y coord
*
* @note This iterator may return messages outside of the search radius, so distance checking should be performed by the user
*/
inline __device__ Filter operator() (const float &x, const float &y) const {
return Filter(metadata, combined_hash, x, y);
}
/**
* Returns a WrapFilter object which provides access to message iterator
* for iterating a subset of messages including those within the radius of the search origin
*
* @param x Search origin x coord
* @param y Search origin y coord
*
* @note Unlike the regular iterator, this iterator will not return messages outside of the search radius. The wrapped distance can be returned via WrapFilter::Message::distance()
*/
inline __device__ WrapFilter wrap(const float& x, const float& y) const {
return WrapFilter(metadata, combined_hash, x, y);
}

/**
* Returns the search radius of the message list defined in the model description
Expand Down Expand Up @@ -314,6 +570,32 @@ T MessageSpatial2D::In::Filter::Message::getVariable(const char(&variable_name)[
T value = detail::curve::Curve::getMessageArrayVariable<T, N>(variable_name, this->_parent.combined_hash, cell_index, array_index);
return value;
}
template<typename T, unsigned int N>
__device__ T MessageSpatial2D::In::WrapFilter::Message::getVariable(const char(&variable_name)[N]) const {
#if !defined(SEATBELTS) || SEATBELTS
// Ensure that the message is within bounds.
if (relative_cell[0] >= 2) {
DTHROW("MessageSpatial2D in invalid bin, unable to get variable '%s'.\n", variable_name);
return static_cast<T>(0);
}
#endif
// get the value from curve using the stored hashes and message index.
T value = detail::curve::Curve::getMessageVariable<T>(variable_name, this->_parent.combined_hash, cell_index);
return value;
}
template<typename T, MessageNone::size_type N, unsigned int M> __device__
T MessageSpatial2D::In::WrapFilter::Message::getVariable(const char(&variable_name)[M], const unsigned int& array_index) const {
#if !defined(SEATBELTS) || SEATBELTS
// Ensure that the message is within bounds.
if (relative_cell[0] >= 2) {
DTHROW("MessageSpatial2D in invalid bin, unable to get variable '%s'.\n", variable_name);
return {};
}
#endif
// get the value from curve using the stored hashes and message index.
T value = detail::curve::Curve::getMessageArrayVariable<T, N>(variable_name, this->_parent.combined_hash, cell_index, array_index);
return value;
}


__device__ __forceinline__ MessageSpatial2D::GridPos2D getGridPosition2D(const MessageSpatial2D::MetaData *md, float x, float y) {
Expand Down Expand Up @@ -385,6 +667,45 @@ __device__ inline MessageSpatial2D::In::Filter::Message& MessageSpatial2D::In::F
}
return *this;
}
__device__ inline MessageSpatial2D::In::WrapFilter::WrapFilter(const MetaData* _metadata, const detail::curve::Curve::NamespaceHash& _combined_hash, const float& x, const float& y)
: metadata(_metadata)
, combined_hash(_combined_hash) {
loc[0] = x;
loc[1] = y;
cell = getGridPosition2D(_metadata, x, y);
}
__device__ inline MessageSpatial2D::In::WrapFilter::Message& MessageSpatial2D::In::WrapFilter::Message::operator++() {
do {
cell_index++;
bool move_cell = cell_index >= cell_index_max;
while (move_cell) {
nextCell();
cell_index = 0;
cell_index_max = 1;
if (relative_cell[0] < 2) {
// Wrap the cell (simply add grid width and use remainder op, relative should not be less than - grid width)
int absolute_cell_x = (_parent.cell.x + relative_cell[0] + static_cast<int>(_parent.metadata->gridDim[0])) % _parent.metadata->gridDim[0];
int absolute_cell_y = (_parent.cell.y + relative_cell[1] + static_cast<int>(_parent.metadata->gridDim[1])) % _parent.metadata->gridDim[1];
unsigned int start_hash = getHash2D(_parent.metadata, { absolute_cell_x, absolute_cell_y });
// Lookup start and end indicies from PBM
cell_index = _parent.metadata->PBM[start_hash];
cell_index_max = _parent.metadata->PBM[start_hash + 1];
}
move_cell = cell_index >= cell_index_max;
}
// If message is out of bounds, break
if (relative_cell[0] >= 2) {
distance = 0;
break;
}
// Else, fetch it's location and update distance
const float msg_x = getVariable<float>("x");
const float msg_y = getVariable<float>("y");
distance = wrapped_distance(msg_x, msg_y);
} while (distance > _parent.metadata->radius);
return *this;
}


} // namespace flamegpu

Expand Down

0 comments on commit 0c5a64a

Please sign in to comment.