Skip to content

Commit

Permalink
Port changes from the main repository (mostly related to large blocks)
Browse files Browse the repository at this point in the history
Skip neighbor list for very small systems

    openmm#4070

Store bounding box sizes in half precision

    openmm@2ae50f9

Use large blocks to optimize building the neighbor list

    openmm@3955033

Improved sorting of blocks when building neighbor list

    openmm@796ffaa

Fixed bug in large blocks optimization with triclinic boxes

    openmm@4c10732

Optimize sorting of non-uniformly distributed data

    openmm@71d9bb1

Co-authored-by: bdenhollander <44237618+bdenhollander@users.noreply.github.com>
  • Loading branch information
ex-rzr and bdenhollander committed Aug 25, 2024
1 parent 5d4b462 commit aa4a823
Show file tree
Hide file tree
Showing 7 changed files with 300 additions and 58 deletions.
18 changes: 13 additions & 5 deletions platforms/hip/include/HipNonbondedUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2022 Stanford University and the Authors. *
* Portions copyright (c) 2009-2023 Stanford University and the Authors. *
* Portions copyright (C) 2020-2023 Advanced Micro Devices, Inc. All Rights *
* Reserved. *
* Authors: Peter Eastman, Nicholas Curtis *
Expand Down Expand Up @@ -83,8 +83,10 @@ class OPENMM_EXPORT_COMMON HipNonbondedUtilities : public NonbondedUtilities {
* @param exclusionList for each atom, specifies the list of other atoms whose interactions should be excluded
* @param kernel the code to evaluate the interaction
* @param forceGroup the force group in which the interaction should be calculated
* @param usesNeighborList specifies whether a neighbor list should be used to optimize this interaction. This should
* be viewed as only a suggestion. Even when it is false, a neighbor list may be used anyway.
*/
void addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const std::vector<std::vector<int> >& exclusionList, const std::string& kernel, int forceGroup);
void addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const std::vector<std::vector<int> >& exclusionList, const std::string& kernel, int forceGroup, bool usesNeighborList = true);
/**
* Add a nonbonded interaction to be evaluated by the default interaction kernel.
*
Expand All @@ -95,9 +97,11 @@ class OPENMM_EXPORT_COMMON HipNonbondedUtilities : public NonbondedUtilities {
* @param exclusionList for each atom, specifies the list of other atoms whose interactions should be excluded
* @param kernel the code to evaluate the interaction
* @param forceGroup the force group in which the interaction should be calculated
* @param usesNeighborList specifies whether a neighbor list should be used to optimize this interaction. This should
* be viewed as only a suggestion. Even when it is false, a neighbor list may be used anyway.
* @param supportsPairList specifies whether this interaction can work with a neighbor list that uses a separate pair list
*/
void addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const std::vector<std::vector<int> >& exclusionList, const std::string& kernel, int forceGroup, bool supportsPairList);
void addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const std::vector<std::vector<int> >& exclusionList, const std::string& kernel, int forceGroup, bool usesNeighborList, bool supportsPairList);
/**
* Add a per-atom parameter that the default interaction kernel may depend on.
*/
Expand Down Expand Up @@ -336,20 +340,23 @@ class OPENMM_EXPORT_COMMON HipNonbondedUtilities : public NonbondedUtilities {
HipArray sortedBlocks;
HipArray sortedBlockCenter;
HipArray sortedBlockBoundingBox;
HipArray blockSizeRange;
HipArray largeBlockCenter;
HipArray largeBlockBoundingBox;
HipArray oldPositions;
HipArray rebuildNeighborList;
HipSort* blockSorter;
hipEvent_t downloadCountEvent;
unsigned int* pinnedCountBuffer;
std::vector<void*> forceArgs, findBlockBoundsArgs, sortBoxDataArgs, findInteractingBlocksArgs, copyInteractionCountsArgs;
std::vector<void*> forceArgs, findBlockBoundsArgs, computeSortKeysArgs, sortBoxDataArgs, findInteractingBlocksArgs, copyInteractionCountsArgs;
std::vector<std::vector<int> > atomExclusions;
std::vector<ParameterInfo> parameters;
std::vector<ParameterInfo> arguments;
std::vector<std::string> energyParameterDerivatives;
std::map<int, double> groupCutoff;
std::map<int, std::string> groupKernelSource;
double lastCutoff;
bool useCutoff, usePeriodic, anyExclusions, usePadding, forceRebuildNeighborList, canUsePairList;
bool useCutoff, usePeriodic, anyExclusions, usePadding, useNeighborList, forceRebuildNeighborList, canUsePairList, useLargeBlocks;
int startTileIndex, startBlockIndex, numBlocks, numTilesInBatch, maxExclusions;
int numForceThreadBlocks, forceThreadBlockSize, findInteractingBlocksThreadBlockSize, numAtoms, groupFlags;
unsigned int maxTiles, maxSinglePairs, tilesAfterReorder;
Expand All @@ -368,6 +375,7 @@ class HipNonbondedUtilities::KernelSet {
std::string source;
hipFunction_t forceKernel, energyKernel, forceEnergyKernel;
hipFunction_t findBlockBoundsKernel;
hipFunction_t computeSortKeysKernel;
hipFunction_t sortBoxDataKernel;
hipFunction_t findInteractingBlocksKernel;
hipFunction_t copyInteractionCountsKernel;
Expand Down
2 changes: 1 addition & 1 deletion platforms/hip/src/HipKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ void HipCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
}
source = cu.replaceStrings(source, replacements);
if (force.getIncludeDirectSpace())
cu.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup(), true);
cu.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup(), numParticles > 3000, true);

// Initialize the exceptions.

Expand Down
93 changes: 67 additions & 26 deletions platforms/hip/src/HipNonbondedUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2022 Stanford University and the Authors. *
* Portions copyright (c) 2009-2023 Stanford University and the Authors. *
* Portions copyright (C) 2020-2023 Advanced Micro Devices, Inc. All Rights *
* Reserved. *
* Authors: Peter Eastman, Nicholas Curtis *
Expand Down Expand Up @@ -51,21 +51,18 @@ using namespace std;

class HipNonbondedUtilities::BlockSortTrait : public HipSort::SortTrait {
public:
BlockSortTrait(bool useDouble) : useDouble(useDouble) {
}
int getDataSize() const {return useDouble ? sizeof(double2) : sizeof(float2);}
int getKeySize() const {return useDouble ? sizeof(double) : sizeof(float);}
const char* getDataType() const {return "real2";}
const char* getKeyType() const {return "real";}
const char* getMinKey() const {return "-3.40282e+38f";}
const char* getMaxKey() const {return "3.40282e+38f";}
const char* getMaxValue() const {return "make_real2(3.40282e+38f, 3.40282e+38f)";}
const char* getSortKey() const {return "value.x";}
private:
bool useDouble;
BlockSortTrait() {}
int getDataSize() const {return sizeof(int);}
int getKeySize() const {return sizeof(int);}
const char* getDataType() const {return "unsigned int";}
const char* getKeyType() const {return "unsigned int";}
const char* getMinKey() const {return "0";}
const char* getMaxKey() const {return "0xFFFFFFFFu";}
const char* getMaxValue() const {return "0xFFFFFFFFu";}
const char* getSortKey() const {return "value";}
};

HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(context), useCutoff(false), usePeriodic(false), anyExclusions(false), usePadding(true),
HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(context), useCutoff(false), usePeriodic(false), useNeighborList(false), anyExclusions(false), usePadding(true),
blockSorter(NULL), pinnedCountBuffer(NULL), forceRebuildNeighborList(true), lastCutoff(0.0), groupFlags(0), canUsePairList(true), tilesAfterReorder(0) {
// Decide how many thread blocks to use.

Expand All @@ -75,6 +72,13 @@ HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(cont
numForceThreadBlocks = 5*4*context.getMultiprocessors();
forceThreadBlockSize = 64;
findInteractingBlocksThreadBlockSize = context.getSIMDWidth();

// When building the neighbor list, we can optionally use large blocks (32 * warpSize atoms) to
// accelerate the process. This makes building the neighbor list faster, but it prevents
// us from sorting atom blocks by size, which leads to a slightly less efficient neighbor
// list. We guess based on system size which will be faster.

useLargeBlocks = (context.getNumAtoms() > 90000);
setKernelSource(HipKernelSources::nonbonded);
}

Expand All @@ -86,11 +90,11 @@ HipNonbondedUtilities::~HipNonbondedUtilities() {
hipEventDestroy(downloadCountEvent);
}

void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
addInteraction(usesCutoff, usesPeriodic, usesExclusions, cutoffDistance, exclusionList, kernel, forceGroup, false);
void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup, bool usesNeighborList) {
addInteraction(usesCutoff, usesPeriodic, usesExclusions, cutoffDistance, exclusionList, kernel, forceGroup, usesNeighborList, false);
}

void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup, bool supportsPairList) {
void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup, bool usesNeighborList, bool supportsPairList) {
if (groupCutoff.size() > 0) {
if (usesCutoff != useCutoff)
throw OpenMMException("All Forces must agree on whether to use a cutoff");
Expand All @@ -103,6 +107,7 @@ void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, b
requestExclusions(exclusionList);
useCutoff = usesCutoff;
usePeriodic = usesPeriodic;
useNeighborList |= (usesNeighborList && useCutoff);
groupCutoff[forceGroup] = cutoffDistance;
groupFlags |= 1<<forceGroup;
canUsePairList &= supportsPairList;
Expand Down Expand Up @@ -291,15 +296,23 @@ void HipNonbondedUtilities::initialize(const System& system) {
int elementSize = (context.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
blockCenter.initialize(context, numAtomBlocks, 4*elementSize, "blockCenter");
blockBoundingBox.initialize(context, numAtomBlocks, 4*elementSize, "blockBoundingBox");
sortedBlocks.initialize(context, numAtomBlocks, 2*elementSize, "sortedBlocks");
sortedBlocks.initialize<unsigned int>(context, numAtomBlocks, "sortedBlocks");
sortedBlockCenter.initialize(context, numAtomBlocks+1, 4*elementSize, "sortedBlockCenter");
sortedBlockBoundingBox.initialize(context, numAtomBlocks+1, 4*elementSize, "sortedBlockBoundingBox");
blockSizeRange.initialize(context, 2, elementSize, "blockSizeRange");
largeBlockCenter.initialize(context, numAtomBlocks, 4*elementSize, "largeBlockCenter");
largeBlockBoundingBox.initialize(context, numAtomBlocks*4, elementSize, "largeBlockBoundingBox");
oldPositions.initialize(context, numAtoms, 4*elementSize, "oldPositions");
rebuildNeighborList.initialize<int>(context, 1, "rebuildNeighborList");
blockSorter = new HipSort(context, new BlockSortTrait(context.getUseDoublePrecision()), numAtomBlocks, false);
blockSorter = new HipSort(context, new BlockSortTrait(), numAtomBlocks, false);
vector<unsigned int> count(2, 0);
interactionCount.upload(count);
rebuildNeighborList.upload(&count[0]);
if (context.getUseDoublePrecision()) {
blockSizeRange.upload(vector<double>{1e38, 0});
} else {
blockSizeRange.upload(vector<float>{1e38, 0});
}
}

// Record arguments for kernels.
Expand Down Expand Up @@ -343,17 +356,30 @@ void HipNonbondedUtilities::initialize(const System& system) {
findBlockBoundsArgs.push_back(&blockCenter.getDevicePointer());
findBlockBoundsArgs.push_back(&blockBoundingBox.getDevicePointer());
findBlockBoundsArgs.push_back(&rebuildNeighborList.getDevicePointer());
findBlockBoundsArgs.push_back(&sortedBlocks.getDevicePointer());
findBlockBoundsArgs.push_back(&blockSizeRange.getDevicePointer());
computeSortKeysArgs.push_back(&blockBoundingBox.getDevicePointer());
computeSortKeysArgs.push_back(&sortedBlocks.getDevicePointer());
computeSortKeysArgs.push_back(&blockSizeRange.getDevicePointer());
sortBoxDataArgs.push_back(&sortedBlocks.getDevicePointer());
sortBoxDataArgs.push_back(&blockCenter.getDevicePointer());
sortBoxDataArgs.push_back(&blockBoundingBox.getDevicePointer());
sortBoxDataArgs.push_back(&sortedBlockCenter.getDevicePointer());
sortBoxDataArgs.push_back(&sortedBlockBoundingBox.getDevicePointer());
if (useLargeBlocks) {
sortBoxDataArgs.push_back(&largeBlockCenter.getDevicePointer());
sortBoxDataArgs.push_back(&largeBlockBoundingBox.getDevicePointer());
sortBoxDataArgs.push_back(context.getPeriodicBoxSizePointer());
sortBoxDataArgs.push_back(context.getInvPeriodicBoxSizePointer());
sortBoxDataArgs.push_back(context.getPeriodicBoxVecXPointer());
sortBoxDataArgs.push_back(context.getPeriodicBoxVecYPointer());
sortBoxDataArgs.push_back(context.getPeriodicBoxVecZPointer());
}
sortBoxDataArgs.push_back(&context.getPosq().getDevicePointer());
sortBoxDataArgs.push_back(&oldPositions.getDevicePointer());
sortBoxDataArgs.push_back(&interactionCount.getDevicePointer());
sortBoxDataArgs.push_back(&rebuildNeighborList.getDevicePointer());
sortBoxDataArgs.push_back(&forceRebuildNeighborList);
sortBoxDataArgs.push_back(&blockSizeRange.getDevicePointer());
findInteractingBlocksArgs.push_back(context.getPeriodicBoxSizePointer());
findInteractingBlocksArgs.push_back(context.getInvPeriodicBoxSizePointer());
findInteractingBlocksArgs.push_back(context.getPeriodicBoxVecXPointer());
Expand All @@ -371,6 +397,10 @@ void HipNonbondedUtilities::initialize(const System& system) {
findInteractingBlocksArgs.push_back(&sortedBlocks.getDevicePointer());
findInteractingBlocksArgs.push_back(&sortedBlockCenter.getDevicePointer());
findInteractingBlocksArgs.push_back(&sortedBlockBoundingBox.getDevicePointer());
if (useLargeBlocks) {
findInteractingBlocksArgs.push_back(&largeBlockCenter.getDevicePointer());
findInteractingBlocksArgs.push_back(&largeBlockBoundingBox.getDevicePointer());
}
findInteractingBlocksArgs.push_back(&exclusionIndices.getDevicePointer());
findInteractingBlocksArgs.push_back(&exclusionRowIndices.getDevicePointer());
findInteractingBlocksArgs.push_back(&oldPositions.getDevicePointer());
Expand All @@ -397,23 +427,24 @@ void HipNonbondedUtilities::prepareInteractions(int forceGroups) {
return;
if (groupKernels.find(forceGroups) == groupKernels.end())
createKernelsForGroups(forceGroups);
if (!useCutoff)
return;
if (numTiles == 0)
return;
KernelSet& kernels = groupKernels[forceGroups];
if (usePeriodic) {
if (useCutoff && usePeriodic) {
double4 box = context.getPeriodicBoxSize();
double minAllowedSize = 1.999999*kernels.cutoffDistance;
if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
}
if (!useNeighborList)
return;
if (numTiles == 0)
return;

// Compute the neighbor list.

if (lastCutoff != kernels.cutoffDistance)
forceRebuildNeighborList = true;
context.executeKernelFlat(kernels.findBlockBoundsKernel, &findBlockBoundsArgs[0], context.getPaddedNumAtoms(), context.getSIMDWidth());
context.executeKernelFlat(kernels.computeSortKeysKernel, &computeSortKeysArgs[0], context.getNumAtomBlocks());
blockSorter->sort(sortedBlocks);
context.executeKernelFlat(kernels.sortBoxDataKernel, &sortBoxDataArgs[0], context.getNumAtoms(), 64);
context.executeKernelFlat(kernels.findInteractingBlocksKernel, &findInteractingBlocksArgs[0], context.getNumAtomBlocks() * context.getSIMDWidth() * numTilesInBatch, findInteractingBlocksThreadBlockSize);
Expand All @@ -433,7 +464,7 @@ void HipNonbondedUtilities::computeInteractions(int forceGroups, bool includeFor
kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy);
context.executeKernelFlat(kernel, &forceArgs[0], numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
}
if (useCutoff && numTiles > 0) {
if (useNeighborList && numTiles > 0) {
hipEventSynchronize(downloadCountEvent);
updateNeighborListSize();
}
Expand Down Expand Up @@ -522,6 +553,8 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) {
defines["USE_PERIODIC"] = "1";
if (context.getBoxIsTriclinic())
defines["TRICLINIC"] = "1";
if (useLargeBlocks)
defines["USE_LARGE_BLOCKS"] = "1";
defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
int maxBits = 0;
if (canUsePairList) {
Expand Down Expand Up @@ -550,8 +583,14 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) {
defines["MAX_BITS_FOR_PAIRS"] = context.intToString(maxBits);
defines["NUM_TILES_IN_BATCH"] = context.intToString(numTilesInBatch);
defines["GROUP_SIZE"] = context.intToString(findInteractingBlocksThreadBlockSize);
int binShift = 1;
while (1<<binShift <= context.getNumAtomBlocks())
binShift++;
defines["BIN_SHIFT"] = context.intToString(binShift);
defines["BLOCK_INDEX_MASK"] = context.intToString((1<<binShift)-1);
hipModule_t interactingBlocksProgram = context.createModule(HipKernelSources::vectorOps+HipKernelSources::findInteractingBlocks, defines);
kernels.findBlockBoundsKernel = context.getKernel(interactingBlocksProgram, "findBlockBounds");
kernels.computeSortKeysKernel = context.getKernel(interactingBlocksProgram, "computeSortKeys");
kernels.sortBoxDataKernel = context.getKernel(interactingBlocksProgram, "sortBoxData");
kernels.findInteractingBlocksKernel = context.getKernel(interactingBlocksProgram, "findBlocksWithInteractions");
kernels.copyInteractionCountsKernel = context.getKernel(interactingBlocksProgram, "copyInteractionCounts");
Expand Down Expand Up @@ -670,6 +709,8 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc
defines["USE_EXCLUSIONS"] = "1";
if (isSymmetric)
defines["USE_SYMMETRIC"] = "1";
if (useNeighborList)
defines["USE_NEIGHBOR_LIST"] = "1";
defines["ENABLE_SHUFFLE"] = "1"; // Used only in hippoNonbonded.cc
if (includeForces)
defines["INCLUDE_FORCES"] = "1";
Expand Down
Loading

0 comments on commit aa4a823

Please sign in to comment.