Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding task-based simulator for PCGs #1365

Open
wants to merge 59 commits into
base: repo-refactor
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
d269b40
compiler build
wmdi Oct 16, 2023
f46fd11
Merge branch 'test-substitution' into test-compiler
wmdi Oct 18, 2023
af67e9e
Merge branch 'test-substitution' into test-compiler
wmdi Nov 8, 2023
c015efb
unity dp works
wmdi Nov 15, 2023
6211b84
format
wmdi Nov 15, 2023
d9f1302
Merge remote-tracking branch 'upstream/repo-refactor' into test-compiler
wmdi Jan 24, 2024
fb58a99
fmt
wmdi Jan 24, 2024
02937e1
fix
wmdi Jan 24, 2024
6402ed0
add substitutions, compiler, and their unit tests to CI
wmdi Jan 25, 2024
0c45f61
disable runtime unit test
wmdi Jan 25, 2024
95fa427
minor fix
wmdi Feb 15, 2024
1f7e2b6
(not compilable) visitable issue for OptimalCostState
wmdi Feb 18, 2024
a9a6402
fix machine mapping hash & refactor dp algorithm
wmdi Feb 27, 2024
d8bbcb8
minor fix
wmdi Feb 27, 2024
09d3152
fix variant issue
wmdi Feb 28, 2024
a150d3a
fmt
wmdi Feb 28, 2024
2eb3fdf
fix
wmdi Mar 11, 2024
7598a92
fmt
wmdi Mar 11, 2024
05c8336
fix
wmdi Mar 14, 2024
71aeddb
Merge remote-tracking branch 'upstream/repo-refactor' into test-compiler
wmdi Mar 14, 2024
9345400
add more unit tests
wmdi Mar 18, 2024
c0015df
fmt
wmdi Mar 18, 2024
6d28697
Merge remote-tracking branch 'origin/repo-refactor' into compiler
lockshaw Mar 22, 2024
102f5fb
Fix post-merge
lockshaw Mar 22, 2024
d6e10bb
Add shell hook for sapling development
lockshaw Mar 23, 2024
95fb4cc
changed from nullopt to std::nullopt
Mar 23, 2024
c091479
fix cast issue
wmdi Mar 23, 2024
57bd35f
Merge branch 'test-compiler' of github.com:wmdi/FlexFlow into test-co…
wmdi Mar 23, 2024
54c604a
Fix spdlog cmake issue
lockshaw Mar 24, 2024
a09e528
Merge remote-tracking branch 'refs/remotes/wmdi/test-compiler' into c…
lockshaw Mar 24, 2024
8b914cf
Re-remove submodules
lockshaw Mar 24, 2024
189f323
minor fix & fmt
wmdi Mar 24, 2024
d2eb505
upd tests name to match ci
wmdi Mar 24, 2024
371324a
Add TEST_SUITE declaration to make tests findable by ctest
lockshaw Mar 26, 2024
da74817
Remove unnecessary nix files, add utils test to ci
lockshaw Mar 26, 2024
0db60db
Fix utils tests name, format
lockshaw Mar 26, 2024
6e520bb
Merge pull request #1229 from wmdi/test-compiler
wmdi Mar 26, 2024
2fd8bfe
initial draft for machine_mapping.cpp
Apr 8, 2024
ffcb8c0
Added get_successor function
Apr 8, 2024
edf7074
Machine Mapping initial draft
Apr 8, 2024
b0cf1b2
Machine Mapping initial draft
Apr 8, 2024
502c75c
Merge branch 'ff-cost-estimator' of github.com:Marsella8/FlexFlow int…
Apr 8, 2024
ae95818
Added parallel_estimate_cost function prototype
Apr 8, 2024
64d28fc
Added test draft
Apr 12, 2024
0e3cc4a
Merge remote-tracking branch 'origin/repo-refactor' into ff-cost-esti…
Apr 12, 2024
9533099
Changes file
Apr 12, 2024
9c03409
Formatting
Apr 12, 2024
2af6003
Saving
Apr 26, 2024
3a9850f
Merge branch 'repo-refactor' of github.com:flexflow/FlexFlow into ff-…
Apr 26, 2024
2793cc6
Merge branch 'repo-refactor' of github.com:flexflow/FlexFlow into ff-…
May 14, 2024
798f492
Moved cost estimator to separate file
May 14, 2024
ffa8158
Working on tests
May 15, 2024
f134e8a
Minor changes
May 15, 2024
a8e8551
Updates to cost_estimator
May 24, 2024
3a8b995
Updates to parallel_cost_estimator + implementations for machineviews…
Jun 3, 2024
cd76042
Tests + fixes for parallel cost estimator
Jun 25, 2024
94e6285
Tests + fixes for parallel_cost_estimator
Jun 27, 2024
b469d88
PR fixes
Jul 15, 2024
44ed29e
Formatting
Jul 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions lib/compiler/include/compiler/cost_estimator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef _FLEXFLOW_COMPILER_COST_ESTIMATOR_H
#define _FLEXFLOW_COMPILER_COST_ESTIMATOR_H

#include "compiler/machine_mapping.h"
#include "cost_estimate.h"
#include "pcg/machine_specification.h"
#include "pcg/machine_view.h"
#include "pcg/parallel_computation_graph.h"
#include "substitutions/sub_parallel_computation_graph.h"

namespace FlexFlow {

// Graph view type consumed by the cost simulator. Declared inside the
// FlexFlow namespace so that the header does not inject a name into the
// global namespace and so that the graph/operator types it names are looked
// up in the namespace this project's code lives in.
using SubParallelComputationGraphView =
    OutputLabelledOpenMultiDiGraphView<Operator, ParallelTensor>;

/**
 * @brief Estimates the execution time of a sub-PCG by simulating the
 *        execution of its nodes on the devices they are mapped to.
 *
 * @param g                      graph whose nodes are simulated
 * @param estimator              provides the cost of each individual node
 * @param device_mapping         machine view (and thus device set) per node
 * @param frontier_machine_views input edges of the graph; the destination
 *                               node of each edge seeds the simulation
 * @return the simulated time at which the last node finishes
 */
float parallel_estimate_cost(
    SubParallelComputationGraphView const &g,
    CostEstimator const &estimator,
    MachineMapping const &device_mapping,
    std::unordered_map<InputMultiDiEdge, MachineView> const
        &frontier_machine_views);

} // namespace FlexFlow

#endif
130 changes: 130 additions & 0 deletions lib/compiler/src/cost_estimator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#include "compiler/cost_estimator.h"
#include "compiler/cost_estimate.h"
#include "compiler/machine_mapping.h"
#include "pcg/parallel_computation_graph.h"
#include "utils/deduplicated_priority_queue.h"
#include "utils/exception.h"
#include "utils/graph/serialparallel.h"
#include <algorithm>

namespace FlexFlow {

/**
 * @brief Estimated execution cost of a single node.
 *
 * Collects the shape attached to every incoming edge of @p node, then asks
 * @p estimator for the cost of the node's operator attributes on those input
 * shapes under the machine view that @p device_mapping assigns to the node.
 *
 * NOTE(review): the input shapes are gathered from an unordered set, so their
 * order in the vector handed to the estimator is unspecified -- confirm the
 * estimator does not depend on input order.
 */
float node_estimate_cost(Node const &node,
                         SubParallelComputationGraphView const &g,
                         CostEstimator const &estimator,
                         MachineMapping const &device_mapping) {
  std::unordered_set<UpwardOpenMultiDiEdge> in_edges =
      get_incoming_edges(g, node);

  // Maps an incoming edge to the shape of the tensor it carries.
  auto shape_of_edge = [&](UpwardOpenMultiDiEdge const &e) {
    return g.at(e).get_shape();
  };
  std::vector<ParallelTensorShape> input_shapes =
      transform(as_vector(in_edges), shape_of_edge);

  return estimator.estimate_cost(
      g.at(node).attrs, input_shapes, device_mapping.machine_views.at(node));
}

// A graph node paired with the simulated time at which it finishes executing.
// Instances are aggregate-initialised as {node, endtime}, so the field order
// below is significant.
struct TimedNode { // Node and associated finishing time
// The node being executed.
Node node;
// Simulated time at which `node` completes (start time + estimated cost).
req<float> endtime;
};
// NOTE(review): presumably this macro supplies the hashing/equality that
// std::unordered_set<TimedNode> relies on -- confirm against its definition.
FF_VISITABLE_STRUCT(TimedNode, node, endtime);

// Ordering used by the queue of in-flight nodes.
//
// The simulator needs `top()` to be the node that finishes EARLIEST. A
// std::priority_queue-style container exposes the greatest element under its
// comparator, so a node counts as "less" when it finishes LATER (i.e. this is
// a min-heap on endtime). Using `<` here would surface the latest-finishing
// node first and make the simulated clock run backwards.
// NOTE(review): assumes DeduplicatedPriorityQueue follows
// std::priority_queue comparator semantics -- confirm.
struct TimeComparison {
  bool operator()(TimedNode const &lhs, TimedNode const &rhs) const {
    return (lhs.endtime > rhs.endtime);
  }
};

bool predecessors_have_been_processed(
std::unordered_set<Node> const &predecessors,
std::unordered_set<TimedNode> processed) {
std::unordered_set<Node> simple_processed =
transform(processed, [](TimedNode const &tn) { return tn.node; });

return all_of(predecessors, [&simple_processed](Node p) {
return simple_processed.find(p) != simple_processed.end();
});
}

// Returns the ids of the devices that `node` runs on, taken from the machine
// view that `device_mapping` assigns to it. The lookup uses `.at`, so a node
// without an entry in the mapping is an error rather than a silent default.
std::vector<device_id_t> get_devices(Node const &node,
                                     MachineMapping const &device_mapping) {
  auto const &view = device_mapping.machine_views.at(node);
  return view.device_ids();
}

/**
 * @brief Estimates the execution time of a sub-PCG with a task-based
 *        simulation.
 *
 * Nodes move through three stages: `frontier` (ready to run), `processing`
 * (running, keyed by finishing time) and `processed` (finished). A frontier
 * node is started only when none of its devices is occupied. Each round
 * starts every startable frontier node at the current simulated time and then
 * drains all running nodes, releasing their devices and promoting successors
 * whose predecessors have all finished.
 *
 * @param g                      graph to simulate
 * @param estimator              provides per-node execution cost
 * @param device_mapping         machine view (device set) of every node
 * @param frontier_machine_views input edges of the graph; only the edges are
 *                               used here -- the destination node of each one
 *                               seeds the frontier
 * @return simulated time at which the last node finishes (0 if nothing runs)
 */
float parallel_estimate_cost(
    SubParallelComputationGraphView const &g,
    CostEstimator const &estimator,
    MachineMapping const &device_mapping,
    std::unordered_map<InputMultiDiEdge, MachineView> const
        &frontier_machine_views) {
  float current_time = 0;
  // Nodes whose dependencies have been met and that are waiting to start.
  std::unordered_set<Node> frontier;
  // Nodes currently executing.
  DeduplicatedPriorityQueue<TimedNode, std::vector<TimedNode>, TimeComparison>
      processing;
  // Nodes that have finished executing.
  std::unordered_set<TimedNode> processed;
  // Devices currently in use; a missing entry means "free".
  std::unordered_map<device_id_t, bool> occupied;

  // Seed the frontier with the destination of every input edge.
  for (auto const &[edge, _] : frontier_machine_views) {
    frontier.insert(get_dst_node(edge));
  }

  auto start_node_processing = [&](Node const &node,
                                   std::vector<device_id_t> const &devices) {
    float cost = node_estimate_cost(node, g, estimator, device_mapping);
    processing.push({node, current_time + cost});
    for (device_id_t d : devices) {
      occupied[d] = true;
    }
    frontier.erase(node);
  };

  auto finish_node_processing = [&](TimedNode const &finished) {
    std::vector<device_id_t> devices =
        get_devices(finished.node, device_mapping);
    for (device_id_t d : devices) { // free devices
      occupied[d] = false;
    }
    processed.insert(finished);
    // The simulated clock must never move backwards: take the maximum so the
    // result is the latest finishing time seen, whatever order the queue
    // yields its entries in.
    float const end_time = finished.endtime;
    current_time = std::max(current_time, end_time);
  };

  while (!frontier.empty() || !processing.empty()) {
    // Start every frontier node whose devices are all free. Iterate over a
    // copy because starting a node erases it from `frontier`.
    std::unordered_set<Node> frontier_copy(frontier);
    for (Node const &node : frontier_copy) {
      std::vector<device_id_t> devices = get_devices(node, device_mapping);
      if (all_of(devices, [&occupied](device_id_t d) { return !occupied[d]; })) {
        start_node_processing(node, devices);
      }
    }

    // Drain all running nodes before starting any new ones.
    while (!processing.empty()) {
      TimedNode finished = processing.top(); // copy: pop() invalidates top()
      processing.pop();
      finish_node_processing(finished);

      // A successor becomes ready once all of its predecessors have finished.
      for (Node const &successor : get_successors(g, finished.node)) {
        std::unordered_set<Node> predecessors = get_predecessors(g, successor);
        if (predecessors_have_been_processed(predecessors, processed)) {
          frontier.insert(successor);
        }
      }
    }
  }
  return current_time;
}
} // namespace FlexFlow
3 changes: 3 additions & 0 deletions lib/compiler/src/machine_mapping.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include "utils/exception.h"
#include "utils/graph/serialparallel.h"

#include "utils/deduplicated_priority_queue.h"
#include <algorithm>

namespace FlexFlow {

MachineMapping MachineMapping::combine(MachineMapping const &s1,
Expand Down
1 change: 1 addition & 0 deletions lib/compiler/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ ff_add_test_executable(
compiler
doctest
utils-test-common
pcg
)
Loading