Skip to content

Commit db2aece

Browse files
authoredNov 11, 2024
[PROTON][NFC] Clean up code (triton-lang#5109)
Including typos, unused function, README, tool usage, and function simplification
1 parent be510cc commit db2aece

File tree

8 files changed

+26
-19
lines changed

8 files changed

+26
-19
lines changed
 

‎third_party/proton/README.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,17 @@ By default, proton profiles are in the *json* format and can be read by *Hatchet
141141
pip install llnl-hatchet
142142
proton-viewer -m time/s <profile.hatchet>
143143
```
144+
144145
NOTE: `pip install hatchet` does not work because the API is slightly different.
145146

146147
### Visualizing sorted profile data
148+
147149
In addition to visualizing the profile data on the terminal through Hatchet, a sorted list of the kernels by the first metric can be printed using the `--print-sorted` flag with proton-viewer.
148150

149151
```bash
150152
proton-viewer -m time/ns,time/% <profile.hatchet> --print-sorted
151153
```
154+
152155
This prints the kernels sorted by time/ns, since it is the first metric listed.
153156

154157
More options can be found by running the following command.
@@ -157,21 +160,27 @@ More options can be found by running the following command.
157160
proton-viewer -h
158161
```
159162

160-
### Advanced features
163+
## Advanced features
164+
165+
### Instrumentation (experimental)
166+
161167
In addition to profiling, Proton also incorporates MLIR/LLVM based compiler instrumentation passes to get Triton level analysis
162168
and optimization information. This feature is under active development and the list of available passes is expected to grow.
163169

164170
#### Available passes
171+
165172
print-mem-spaces: this pass prints the load and store address spaces (e.g. global, flat, shared) chosen by the compiler and attributes them back to Triton source information.
166173

167174
Example usage with the Proton matmul tutorial:
175+
168176
```bash
169177
$ proton --instrument=print-mem-spaces matmul.py
170178
0 matmul_kernel matmul.py:180:20 SHARED STORE
171179
1 matmul_kernel matmul.py:181:20 SHARED STORE
172180
2 matmul_kernel matmul.py:180:20 SHARED LOAD
173181
3 matmul_kernel matmul.py:181:20 SHARED LOAD
174182
```
183+
175184
Note: The instrument functionality is currently only available from the command line. Additionally, the instrument and profile command line arguments cannot be used simultaneously.
176185

177186
### Instruction sampling (experimental)

‎third_party/proton/csrc/include/Profiler/GPUProfiler.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,15 @@ class GPUProfiler : public Profiler,
5858

5959
ThreadState(ConcreteProfilerT &profiler) : profiler(profiler) {}
6060

61-
void record(size_t scopeId) {
61+
size_t record() {
62+
auto scopeId = Scope::getNewScopeId();
6263
if (profiler.isOpInProgress())
63-
return;
64+
return scopeId;
6465
std::set<Data *> dataSet = profiler.getDataSet();
6566
for (auto data : dataSet)
6667
data->addScope(scopeId);
6768
profiler.correlation.apiExternIds.insert(scopeId);
69+
return scopeId;
6870
}
6971

7072
void enterOp(size_t scopeId) {

‎third_party/proton/csrc/include/Session/Session.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class SessionManager : public Singleton<SessionManager> {
7575

7676
void finalizeAllSessions(OutputFormat outputFormat);
7777

78-
void activateSession(size_t sesssionId);
78+
void activateSession(size_t sessionId);
7979

8080
void deactivateSession(size_t sessionId);
8181

@@ -97,7 +97,7 @@ class SessionManager : public Singleton<SessionManager> {
9797
const std::string &contextSourceName,
9898
const std::string &dataName);
9999

100-
void activateSessionImpl(size_t sesssionId);
100+
void activateSessionImpl(size_t sessionId);
101101

102102
void deActivateSessionImpl(size_t sessionId);
103103

@@ -135,7 +135,7 @@ class SessionManager : public Singleton<SessionManager> {
135135
// path -> session id
136136
std::map<std::string, size_t> sessionPaths;
137137
// session id -> active
138-
std::map<size_t, bool> activeSessions;
138+
std::map<size_t, bool> sessionActive;
139139
// session id -> session
140140
std::map<size_t, std::unique_ptr<Session>> sessions;
141141
// scope -> active count

‎third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,7 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData,
323323
static_cast<const CUpti_CallbackData *>(cbData);
324324
auto *pImpl = dynamic_cast<CuptiProfilerPimpl *>(profiler.pImpl.get());
325325
if (callbackData->callbackSite == CUPTI_API_ENTER) {
326-
auto scopeId = Scope::getNewScopeId();
327-
threadState.record(scopeId);
326+
auto scopeId = threadState.record();
328327
threadState.enterOp(scopeId);
329328
size_t numInstances = 1;
330329
if (cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch ||

‎third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,7 @@ void RoctracerProfiler::RoctracerProfilerPimpl::apiCallback(
231231
const hip_api_data_t *data = (const hip_api_data_t *)(callbackData);
232232
if (data->phase == ACTIVITY_API_PHASE_ENTER) {
233233
// Valid context and outermost level of the kernel launch
234-
auto scopeId = Scope::getNewScopeId();
235-
threadState.record(scopeId);
234+
auto scopeId = threadState.record();
236235
threadState.enterOp(scopeId);
237236
size_t numInstances = 1;
238237
if (cid == HIP_API_ID_hipGraphLaunch) {

‎third_party/proton/csrc/lib/Session/Session.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -91,20 +91,20 @@ void SessionManager::deactivateSession(size_t sessionId) {
9191

9292
void SessionManager::activateSessionImpl(size_t sessionId) {
9393
throwIfSessionNotInitialized(sessions, sessionId);
94-
if (activeSessions[sessionId])
94+
if (sessionActive[sessionId])
9595
return;
96-
activeSessions[sessionId] = true;
96+
sessionActive[sessionId] = true;
9797
sessions[sessionId]->activate();
9898
registerInterface<ScopeInterface>(sessionId, scopeInterfaceCounts);
9999
registerInterface<OpInterface>(sessionId, opInterfaceCounts);
100100
}
101101

102102
void SessionManager::deActivateSessionImpl(size_t sessionId) {
103103
throwIfSessionNotInitialized(sessions, sessionId);
104-
if (!activeSessions[sessionId]) {
104+
if (!sessionActive[sessionId]) {
105105
return;
106106
}
107-
activeSessions[sessionId] = false;
107+
sessionActive[sessionId] = false;
108108
sessions[sessionId]->deactivate();
109109
unregisterInterface<ScopeInterface>(sessionId, scopeInterfaceCounts);
110110
unregisterInterface<OpInterface>(sessionId, opInterfaceCounts);
@@ -204,7 +204,7 @@ void SessionManager::addMetrics(
204204
size_t scopeId, const std::map<std::string, MetricValueType> &metrics,
205205
bool aggregable) {
206206
std::shared_lock<std::shared_mutex> lock(mutex);
207-
for (auto [sessionId, active] : activeSessions) {
207+
for (auto [sessionId, active] : sessionActive) {
208208
if (active) {
209209
sessions[sessionId]->data->addMetrics(scopeId, metrics, aggregable);
210210
}

‎third_party/proton/proton/proton.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22
import sys
33
import os
4-
from glob import glob
54
import pathlib
65
from .profile import start, finalize, _select_backend
76
from .flags import set_command_line
@@ -91,7 +90,6 @@ def run_profiling(args, target_args):
9190

9291

9392
def run_instrumentation(args, target_args):
94-
backend = args.backend if args.backend else _select_backend()
9593
do_setup_and_execute(target_args, args.instrument)
9694

9795

‎third_party/proton/proton/viewer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def get_min_time_bytes(df, device_info):
9292
}
9393

9494
# FLOPS have a specific width to their metric
95-
default_flop_factor_dict = {f"flop/s": 1, f"gflop/s": 1e9, f"tflop/s": 1e12}
95+
default_flop_factor_dict = {"flop/s": 1, "gflop/s": 1e9, "tflop/s": 1e12}
9696
derivable_metrics.update(
9797
{key: FactorDict("flops", default_flop_factor_dict)
9898
for key in default_flop_factor_dict.keys()})
@@ -278,7 +278,7 @@ def main():
278278
type=str,
279279
default=None,
280280
help="""Exclude frames that match the given regular expression and their children.
281-
For example, the following command will exclude all paths that contain frames that contains "test":
281+
For example, the following command will exclude all paths starting from "test":
282282
```
283283
proton-viewer -e ".*test.*" path/to/file.json
284284
```

0 commit comments

Comments
 (0)