Skip to content

Commit

Permalink
Make session_handle of DirectSession an unique identifier and accessi…
Browse files Browse the repository at this point in the history
…ble from OptimizationPassRegistry::PRE_PLACEMENT pass.

PiperOrigin-RevId: 227884596
  • Loading branch information
tensorflower-gardener committed Jan 4, 2019
1 parent d76c2d0 commit 735d26a
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tensorflow/core/common_runtime/direct_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
Expand Down Expand Up @@ -303,10 +304,8 @@ DirectSession::DirectSession(const SessionOptions& options,
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
// NOTE(mrry): We do not need to use a unique string for the session
// handle, because DirectSession owns its devices. This may change
// in future versions.
session_handle_ = "direct";
session_handle_ =
strings::StrCat("direct", strings::FpToString(random::New64()));
int devices_added = 0;
if (options.config.log_device_placement()) {
const string mapping_str = device_mgr_->DeviceMappingString();
Expand Down Expand Up @@ -371,6 +370,7 @@ Status DirectSession::MaybeInitializeExecutionState(
GraphExecutionStateOptions options;
options.device_set = &device_set_;
options.session_options = &options_;
options.session_handle = session_handle_;
// TODO(mrry,suharshs): We explicitly copy `graph` so that
// `MakeForBaseGraph()` can take ownership of its
// contents. Previously this happened implicitly in calls to the
Expand Down Expand Up @@ -533,6 +533,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
CancellationManager step_cancellation_manager;
args.cancellation_manager = &step_cancellation_manager;
args.session_state = &session_state_;
args.session_handle = session_handle_;
args.tensor_store = &run_state.tensor_store;
args.step_container = &run_state.step_container;
args.sync_on_finish = sync_on_finish_;
Expand Down Expand Up @@ -888,6 +889,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
SchedClosure(pool, std::move(c));
};
args.session_state = &session_state_;
args.session_handle = session_handle_;
args.tensor_store = &run_state->tensor_store;
args.step_container = &run_state->step_container;
if (LogMemory::IsEnabled()) {
Expand Down Expand Up @@ -1465,6 +1467,7 @@ Status DirectSession::CreateGraphs(
prune_options.device_set = &device_set_;
prune_options.session_options = &options_;
prune_options.stateful_placements = stateful_placements_;
prune_options.session_handle = session_handle_;
TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
execution_state_->original_graph_def().library(), prune_options,
execution_state_->original_graph_def(), subgraph_options,
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/common_runtime/direct_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ class DirectSession : public Session {
std::vector<Device*> devices_; // not owned
DeviceSet device_set_;

// Unique session identifier.
string session_handle_;
mutex graph_state_lock_;
bool graph_created_ GUARDED_BY(graph_state_lock_) = false;
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/core/common_runtime/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,7 @@ class ExecutorState {
Rendezvous* rendezvous_;
CollectiveExecutor* collective_executor_ = nullptr;
SessionState* session_state_;
string session_handle_;
TensorStore* tensor_store_;
// Step-local container.
ScopedStepContainer* step_container_;
Expand Down Expand Up @@ -1371,6 +1372,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
rendezvous_(args.rendezvous),
collective_executor_(args.collective_executor),
session_state_(args.session_state),
session_handle_(args.session_handle),
tensor_store_(args.tensor_store),
step_container_(args.step_container),
stats_collector_(args.stats_collector),
Expand Down Expand Up @@ -1616,6 +1618,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
params.rendezvous = rendezvous_;
params.collective_executor = collective_executor_;
params.session_state = session_state_;
params.session_handle = session_handle_;
params.tensor_store = tensor_store_;
params.cancellation_manager = cancellation_manager_;
params.call_frame = call_frame_;
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/common_runtime/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class Executor {
CallFrameInterface* call_frame = nullptr;
CancellationManager* cancellation_manager = nullptr;
SessionState* session_state = nullptr;
// Unique session identifier. Can be empty.
string session_handle;
TensorStore* tensor_store = nullptr;
ScopedStepContainer* step_container = nullptr;
CollectiveExecutor* collective_executor = nullptr;
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/core/common_runtime/graph_execution_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ GraphExecutionState::GraphExecutionState(
: stateful_placements_(options.stateful_placements),
device_set_(options.device_set),
session_options_(options.session_options),
session_handle_(options.session_handle),
flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(),
graph_def->library())),
graph_(nullptr) {
Expand Down Expand Up @@ -198,6 +199,7 @@ Status GraphExecutionState::Extend(
GraphExecutionStateOptions combined_options;
combined_options.device_set = device_set_;
combined_options.session_options = session_options_;
combined_options.session_handle = session_handle_;
combined_options.stateful_placements = stateful_placements_;

// NOTE(mrry): `gdef` is no longer valid after the constructor
Expand Down Expand Up @@ -558,6 +560,7 @@ Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
RestoreStatefulNodes(new_graph.get());

GraphOptimizationPassOptions optimization_options;
optimization_options.session_handle = session_handle_;
optimization_options.session_options = session_options_;
optimization_options.graph = &new_graph;
optimization_options.flib_def = flib_def_.get();
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/common_runtime/graph_execution_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct RewriteGraphMetadata;
struct GraphExecutionStateOptions {
const DeviceSet* device_set = nullptr;
const SessionOptions* session_options = nullptr;
// Unique session identifier. Can be empty.
string session_handle;
// A map from node name to device name, representing the unchangeable
// placement of stateful nodes.
std::unordered_map<string, string> stateful_placements;
Expand Down Expand Up @@ -192,6 +194,8 @@ class GraphExecutionState {
GraphDef original_graph_def_; // Immutable after ctor.
const DeviceSet* device_set_; // Not owned
const SessionOptions* session_options_; // Not owned
// Unique session identifier. Can be empty.
string session_handle_;

// Map from name to Node for the full graph in placed_.
NodeNameToCostIdMap node_name_to_cost_id_map_;
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/common_runtime/optimization_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct SessionOptions;
// as a key into a state dictionary if it wants to keep state across
// calls.
struct GraphOptimizationPassOptions {
// Filled in by DirectSession for PRE_PLACEMENT optimizations. Can be empty.
string session_handle;
const SessionOptions* session_options = nullptr;
const CostModel* cost_model = nullptr;
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/core/framework/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ class OpKernelContext {
// The session state for this op.
SessionState* session_state = nullptr;

// Unique session identifier. Can be empty.
string session_handle;

// The tensor store for this op.
TensorStore* tensor_store = nullptr;

Expand Down Expand Up @@ -1034,6 +1037,9 @@ class OpKernelContext {
// An op kernel can access the session state it belongs to.
SessionState* session_state() const { return params_->session_state; }

// Unique identifier of the session it belongs to. Can be empty.
string session_handle() const { return params_->session_handle; }

// An op kernel can access the tensor store of the run it belongs to.
TensorStore* tensor_store() const { return params_->tensor_store; }

Expand Down

0 comments on commit 735d26a

Please sign in to comment.