-
Notifications
You must be signed in to change notification settings - Fork 511
/
Copy pathpjrt_c_api_client.h
847 lines (660 loc) · 30.6 KB
/
pjrt_c_api_client.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
/* Copyright 2022 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_PJRT_PJRT_C_API_CLIENT_H_
#define XLA_PJRT_PJRT_C_API_CLIENT_H_
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "xla/client/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_device_description.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/shape.h"
#include "xla/status.h"
#include "xla/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/framework/allocator.h"
namespace xla {
class PjRtCApiClient;
// PjRtDeviceDescription implementation that keeps a `PJRT_DeviceDescription`
// handle together with a `PJRT_Api` pointer. All accessors are only declared
// here; their definitions are out of line and not visible in this header.
class PjRtCApiDeviceDescription : public PjRtDeviceDescription {
public:
// Takes the C API table and the C device description handle.
// NOTE(review): presumably both are stored in `c_api_` /
// `device_description_` -- confirm against the definition.
PjRtCApiDeviceDescription(const PJRT_Api* c_api,
PJRT_DeviceDescription* device_description);
int id() const override;
int process_index() const override;
absl::string_view device_kind() const override;
absl::string_view DebugString() const override;
absl::string_view ToString() const override;
// Returns a reference to an attribute map keyed by attribute name.
// NOTE(review): presumably this is `attributes_` -- confirm in the .cc file.
const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
const override;
private:
const PJRT_Api* c_api_;
// `device_description_` is a raw pointer and is not owned by this object.
// The original note says it is owned by the `PJRT_Client` wrapped by
// `client_`; this class itself has no `client_` member.
PJRT_DeviceDescription* device_description_;
// Device specific attributes with corresponding values.
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
// Initializes device specific attributes.
void InitAttributes();
};
// A PjRtMemorySpace that pairs a `PJRT_Memory` handle with the
// `PjRtCApiClient` it was constructed for. Both are held as raw pointers and
// this class declares no destructor, so neither is released by it.
class PjRtCApiMemorySpace : public PjRtMemorySpace {
 public:
  // Records `c_memory` and `client`; does no other work.
  explicit PjRtCApiMemorySpace(PJRT_Memory* c_memory, PjRtCApiClient* client)
      : client_(client), c_memory_(c_memory) {}

  // PjRtMemorySpace interface. Apart from `devices()`, these are defined out
  // of line.
  PjRtClient* client() const override;

  // Returns a view over `devices_`.
  absl::Span<PjRtDevice* const> devices() const override { return devices_; }

  int id() const override;

  absl::string_view kind() const override;

  int kind_id() const override;

  absl::string_view DebugString() const override;

  absl::string_view ToString() const override;

  // Accessors for the underlying C API objects.
  const PJRT_Api* pjrt_c_api() const;

  PJRT_Memory* c_memory() const { return c_memory_; }

 private:
  // Grants PjRtCApiClient access to the private members below.
  // NOTE(review): presumably used to fill in `devices_` -- confirm in the
  // .cc file.
  friend class PjRtCApiClient;

  PjRtCApiClient* client_;
  PJRT_Memory* c_memory_;
  std::vector<PjRtDevice*> devices_;
};
// PjRtDevice implementation built around a `PJRT_Device` handle. Operations
// that this wrapper does not provide either return an Unimplemented status or
// terminate via LOG(FATAL), as noted on each method.
class PjRtCApiDevice : public PjRtDevice {
 public:
  explicit PjRtCApiDevice(PJRT_Device* device, PjRtCApiClient* client);

  PjRtClient* client() const override;

  bool IsAddressable() const override;

  int local_hardware_id() const override;

  PjRtLocalHardwareId local_hardware_id_typed() const override;

  // Not supported: always returns an Unimplemented status; `literal` is
  // ignored.
  Status TransferToInfeed(const LiteralSlice& literal) override {
    return Unimplemented(
        "PJRT C API does not support TransferToInfeed. Please report an issue "
        "at https://github.com/google/jax/issues if you need this feature.");
  }

  // Not supported: always returns an Unimplemented status; `literal` is
  // ignored.
  Status TransferFromOutfeed(MutableBorrowingLiteral literal) override {
    return Unimplemented(
        "PJRT C API does not support TransferFromOutfeed. Please report an "
        "issue at https://github.com/google/jax/issues if you need this "
        "feature.");
  }

  // Returns a view over `memory_spaces_`.
  absl::Span<PjRtMemorySpace* const> memory_spaces() const override {
    return memory_spaces_;
  }

  StatusOr<PjRtMemorySpace*> default_memory_space() const override;

  // Not supported: logs a FATAL message. The trailing `return nullptr;` only
  // satisfies the return type.
  std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
      absl::string_view description) const override {
    LOG(FATAL)
        << "PJRT C API does not support CreateAsyncTrackingEvent. Please "
           "report an issue at https://github.com/google/jax/issues if you "
           "need this feature.";
    return nullptr;
  }

  // Returns the wrapped C device handle.
  PJRT_Device* c_device() const { return device_; }

  // Returns the description object stored in this device.
  const PjRtCApiDeviceDescription& description() const override {
    return description_;
  }

  StatusOr<tsl::AllocatorStats> GetAllocatorStats() const override;

  absl::StatusOr<std::intptr_t> GetStreamForExternalReadyEvents()
      const override;

 private:
  // Grants PjRtCApiClient access to the private members below.
  friend class PjRtCApiClient;

  PjRtCApiClient* client_ = nullptr;
  // `device_` is owned by the `PJRT_Client` wrapped by `client_`
  PJRT_Device* device_;
  PjRtCApiDeviceDescription description_;
  std::vector<PjRtMemorySpace*> memory_spaces_;
};
// PjRtCompiler implementation that holds a `PJRT_Api` pointer. Both `Compile`
// overloads are defined out of line.
class PjRtCApiCompiler : public PjRtCompiler {
public:
// Stores `c_api`; the pointer is not owned by this object.
explicit PjRtCApiCompiler(const PJRT_Api* c_api) : c_api_(c_api) {}
// Compiles an `XlaComputation` for the given topology.
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
CompileOptions options, const XlaComputation& computation,
const PjRtTopologyDescription& topology, PjRtClient* client) override;
// Compiles an MLIR module for the given topology.
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
CompileOptions options, mlir::ModuleOp module,
const PjRtTopologyDescription& topology, PjRtClient* client) override;
private:
const PJRT_Api* c_api_;
};
// PjRtTopologyDescription implementation that wraps a
// `PJRT_TopologyDescription` handle, optionally owning it (see the
// constructor comment).
class PjRtCApiTopologyDescription : public PjRtTopologyDescription {
 public:
  // `owned` indicates whether this PjRtCApiTopologyDescription should take
  // ownership of `c_topology`, i.e., if owned is true,
  // PJRT_TopologyDescription_Destroy will be called on `c_topology` when this
  // PjRtCApiTopologyDescription is destroyed.
  PjRtCApiTopologyDescription(const PJRT_Api* c_api,
                              PJRT_TopologyDescription* c_topology, bool owned);

  // Not supported: terminates the process. Uses LOG(FATAL), like the other
  // unsupported non-void accessors in this file, so that the function's
  // never-returning nature is explicit.
  PjRtPlatformId platform_id() const override {
    LOG(FATAL) << "PJRT C API does not support platform_id.";
  }

  absl::string_view platform_name() const override;

  absl::string_view platform_version() const override;

  // Returns the raw pointer held by `compiler_`, wrapped in an optional.
  std::optional<PjRtCompiler*> compiler() const override {
    return compiler_.get();
  }

  // Returns the wrapped C topology handle.
  PJRT_TopologyDescription* c_topology() const { return c_topology_; }

  std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
      const override;

  absl::StatusOr<std::string> Serialize() const override;

  // Returns vendor specific attributes about the topology.
  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
      const override {
    return attributes_;
  }

  // Not supported: always returns an Unimplemented status; both arguments are
  // ignored.
  StatusOr<Layout> GetDefaultLayout(
      PrimitiveType element_type,
      absl::Span<const int64_t> dims) const override {
    return Unimplemented("PJRT C API does not support GetDefaultLayout");
  }

 private:
  std::unique_ptr<PjRtCApiCompiler> compiler_;
  const PJRT_Api* c_api_;
  // nullptr iff the PJRT_TopologyDescription isn't owned by this wrapper
  // (i.e. by the caller).
  std::unique_ptr<PJRT_TopologyDescription,
                  ::pjrt::PJRT_TopologyDescriptionDeleter>
      owned_c_topology_;
  PJRT_TopologyDescription* c_topology_;
  // Device specific attributes with corresponding values.
  absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;

  // Initializes device specific attributes.
  void InitAttributes();
};
class PjRtCApiClient : public PjRtClient {
public:
PjRtCApiClient(
const PJRT_Api* c_api, PJRT_Client* c_client,
std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> kv_callback_data);
int process_index() const override;
int device_count() const override;
int addressable_device_count() const override;
absl::Span<PjRtDevice* const> devices() const override;
absl::Span<PjRtDevice* const> addressable_devices() const override;
StatusOr<PjRtDevice*> LookupDevice(
PjRtGlobalDeviceId global_device_id) const override;
StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const override;
StatusOr<PjRtDevice*> LookupAddressableDevice(
PjRtLocalDeviceId local_device_id) const override;
absl::Span<PjRtMemorySpace* const> memory_spaces() const override;
PjRtPlatformId platform_id() const override { return platform_id_; }
absl::string_view platform_name() const override { return platform_name_; };
absl::string_view platform_version() const override;
std::optional<PjRtPluginAttributes> plugin_attributes() const override;
// TODO(b/244756954): Rethink this function altogether
PjRtRuntimeType runtime_type() const override {
return PjRtRuntimeType::kTfrt;
}
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis()
const override {
return Unimplemented(
"PJRT C API does not support GetHloCostAnalysis. Please report an "
"issue at https://github.com/google/jax/issues if you need this "
"feature.");
}
StatusOr<Layout> GetDefaultLayout(PrimitiveType element_type,
absl::Span<const int64_t> dims) override;
StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override;
StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
mlir::ModuleOp module, CompileOptions options) override;
// `PjRtCApiClient::DeserializeExecutable()` ignores `CompileOptions` arg
StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
absl::string_view serialized,
std::optional<CompileOptions> options) override;
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) override {
return Unimplemented(
"PJRT C API does not support CreateUninitializedBuffer. Please report "
"an issue at https://github.com/google/jax/issues if you need this "
"feature.");
}
StatusOr<const PjRtTopologyDescription*> GetTopologyDescription()
const override;
StatusOr<std::unique_ptr<AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(absl::Span<const Shape> shapes,
PjRtDevice* device) override {
return Unimplemented(
"PJRT C API does not support CreateBuffersForAsyncHostToDevice. Please "
"report an issue at https://github.com/google/jax/issues if you need "
"this feature.");
}
absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(absl::Span<const Shape> shapes,
PjRtMemorySpace* memory_space) override {
return Unimplemented(
"PJRT C API does not support CreateBuffersForAsyncHostToDevice. Please "
"report an issue at https://github.com/google/jax/issues if you need "
"this feature.");
}
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
std::optional<absl::Span<int64_t const>> byte_strides,
HostBufferSemantics host_buffer_semantics,
absl::AnyInvocable<void() &&> on_done_with_host_buffer,
PjRtDevice* device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
std::optional<absl::Span<int64_t const>> byte_strides,
HostBufferSemantics host_buffer_semantics,
absl::AnyInvocable<void() &&> on_done_with_host_buffer,
PjRtDevice* device, const Layout* device_layout) override;
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
std::optional<absl::Span<int64_t const>> byte_strides,
HostBufferSemantics host_buffer_semantics,
absl::AnyInvocable<void() &&> on_done_with_host_buffer,
PjRtMemorySpace* memory_space, const Layout* device_layout) override;
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) override {
return Unimplemented(
"PJRT C API does not support BufferFromHostLiteral. Please report an "
"issue at https://github.com/google/jax/issues if you need this "
"feature.");
}
StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
void* device_ptr, const Shape& shape, PjRtDevice* device,
std::function<void()> on_delete_callback,
std::optional<std::intptr_t> stream) override;
StatusOr<std::uintptr_t> UnsafeBufferPointer(PjRtBuffer* buffer) override;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
PjRtDevice* device,
PjRtCrossHostRecvNotifier notifier) override {
return Unimplemented(
"PJRT C API does not support MakeCrossHostReceiveBuffers. Please "
"report an issue at https://github.com/google/jax/issues if you need "
"this feature.");
}
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffersForGather(
absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override {
return Unimplemented(
"PJRT C API does not support MakeCrossHostReceiveBuffers. Please "
"report an issue at https://github.com/google/jax/issues if you need "
"this feature.");
}
StatusOr<ChannelHandle> CreateChannelHandle() override {
return Unimplemented(
"PJRT C API does not support CreateChannelHandle. Please report an "
"issue at https://github.com/google/jax/issues if you need this "
"feature.");
}
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
return Unimplemented(
"PJRT C API does not support CreateDeviceToHostChannelHandle. Please "
"report an issue at https://github.com/google/jax/issues if you need "
"this feature.");
}
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
return Unimplemented(
"PJRT C API does not support CreateHostToDeviceChannelHandle. Please "
"report an issue at https://github.com/google/jax/issues if you need "
"this feature.");
}
Status Defragment() override {
return Unimplemented(
"PJRT C API does not support Defragment. Please report an issue at "
"https://github.com/google/jax/issues if you need this feature.");
}
bool SupportsSendRecvCallbacks() const override { return true; }
const PJRT_Api* pjrt_c_api() const;
PJRT_Client* pjrt_c_client() { return c_client_.get(); }
PjRtCApiDevice* GetCppDevice(PJRT_Device* c_device) const {
auto it = c_to_cpp_device_map_.find(c_device);
CHECK(it != c_to_cpp_device_map_.end());
return it->second;
}
PjRtCApiMemorySpace* GetCppMemory(PJRT_Memory* c_memory) const {
auto it = c_to_cpp_memory_map_.find(c_memory);
CHECK(it != c_to_cpp_memory_map_.end());
return it->second;
}
PjRtHostMemoryForDeviceManager* GetPjRtHostMemoryForDeviceManager()
const override {
return nullptr;
}
private:
void InitDevicesAndMemorySpaces();
void InitAttributes();
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBufferInternalImpl(
const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
std::optional<absl::Span<int64_t const>> byte_strides,
HostBufferSemantics host_buffer_semantics,
absl::AnyInvocable<void() &&> on_done_with_host_buffer,
std::variant<PjRtDevice*, PjRtMemorySpace*> device_or_memory,
const Layout* device_layout);
const PJRT_Api* c_api_;
std::unique_ptr<PJRT_Client, ::pjrt::PJRT_ClientDeleter> c_client_;
std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> kv_callback_data_;
std::vector<std::unique_ptr<PjRtCApiDevice>> owned_devices_;
std::vector<PjRtDevice*> devices_;
std::vector<PjRtDevice*> addressable_devices_;
absl::flat_hash_map<PJRT_Device*, PjRtCApiDevice*> c_to_cpp_device_map_;
std::vector<std::unique_ptr<PjRtCApiMemorySpace>> owned_memory_spaces_;
// TODO(yueshengys): Add a `memory_spaces_` member when global memories are
// supported.
std::vector<PjRtMemorySpace*> addressable_memory_spaces_;
absl::flat_hash_map<PJRT_Memory*, PjRtCApiMemorySpace*> c_to_cpp_memory_map_;
// There may be an error fetching the topology desc via the C API
// (e.g. unimplemented). Save the error during client init so we can return it
// from GetTopologyDescription().
StatusOr<const PjRtCApiTopologyDescription> topo_desc_;
const std::string platform_version_;
const std::string platform_name_;
const PjRtPlatformId platform_id_;
absl::flat_hash_map<std::string, xla::PjRtValueType> attributes_;
};
class PjRtCApiBuffer : public PjRtBuffer {
public:
PjRtCApiBuffer(PjRtCApiClient* client, PJRT_Buffer* buffer);
PrimitiveType element_type() const override;
absl::Span<const int64_t> dimensions() const override;
std::unique_ptr<PjRtLayout> layout() const override;
// PJRT C API doesn't support tuple buffers.
bool IsTuple() const override { return false; }
const Shape& on_device_shape() const override {
LOG(FATAL) << "PjRtBuffer::on_device_shape() not implemented in PJRT C API";
}
bool has_dynamic_dimensions() const override;
absl::Span<const bool> is_dynamic_dimension() const override;
StatusOr<std::vector<int64_t>> logical_dimensions() override;
StatusOr<Shape> logical_on_device_shape() override {
LOG(FATAL) << "PjRtBuffer::on_logical_device_shape() not implemented in "
"PJRT C API";
}
PjRtMemorySpace* memory_space() const override;
PjRtDevice* device() const override;
PjRtClient* client() const override { return client_; }
StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
override;
PjRtFuture<> ToLiteral(MutableLiteralBase* literal) override;
PjRtFuture<> LazyToLiteral(
absl::AnyInvocable<absl::StatusOr<MutableLiteralBase*>() &&> generator)
override;
StatusOr<size_t> GetOnDeviceSizeInBytes() const override;
PjRtFuture<> CopyRawToHost(void* dst, int64_t offset,
int64_t transfer_size) override {
return PjRtFuture<>(Unimplemented(
"PJRT C API does not support CopyRawToHost. Please report an issue at "
"https://github.com/google/jax/issues if you need this feature."));
}
void Delete() override;
StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
bool wait_for_operations_to_complete) override {
return Unimplemented(
"PJRT C API does not support ReleaseDeviceMemoryOwnership");
}
bool IsDeleted() override;
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToMemorySpace(
PjRtMemorySpace* dst_memory_space) override;
void CopyToRemoteDevice(PjRtFuture<std::string> serialized_descriptor,
RemoteSendCallback on_done) override {
LOG(ERROR) << "PJRT C API does not support CopyToRemoteDevice. Please "
"report an issue at https://github.com/google/jax/issues if "
"you need this feature.";
}
void CopyToRemoteDeviceScattered(
PjRtFuture<std::vector<std::string>> serialized_descriptors,
std::vector<RemoteSendCallback> callbacks,
const ScatterDetails& scatter_details) override {
LOG(ERROR)
<< "PJRT C API does not support CopyToRemoteDeviceScattered. Please "
"report an issue at https://github.com/google/jax/issues if you "
"need this feature.";
}
PjRtFuture<> GetReadyFuture() override;
bool IsOnCpu() const override;
PJRT_Buffer* c_buffer() const { return buffer_.get(); }
const PJRT_Api* pjrt_c_api() const { return client_->pjrt_c_api(); }
private:
// Gets the raw pointer to `readiness_event_`. If `readiness_event_` has not
// yet been initialized, this function does so before returning the pointer.
PJRT_Event* GetReadyEvent();
// `MakePromiseTrackEvent` sets `readiness_promise_` up to track
// `readiness_event_`. This is used to implement `GetReadyFuture()`.
// `readiness_promise_` should be created before calling this function.
void MakePromiseTrackEvent();
PjRtCApiClient* client_;
std::unique_ptr<PJRT_Buffer, ::pjrt::PJRT_BufferDeleter> buffer_;
std::unique_ptr<PJRT_Event, ::pjrt::PJRT_EventDeleter> readiness_event_;
// This is a shared_ptr to keep the underlying future alive even if
// `readiness_promise` is destroyed before `readiness_event`, and the callback
// we set on `readiness_event` modifies `readiness_promise_`.
std::shared_ptr<PjRtFuture<>::Promise> readiness_promise_;
// Set and cached the first time layout() is called.
mutable std::optional<PjRtXlaLayout> layout_;
// Set and cached the first time is_dynamic_dimension() is called.
mutable std::optional<absl::InlinedVector<bool, InlineRank()>>
is_dynamic_dimension_;
// Used to synchronize concurrent setting of cached values.
mutable absl::Mutex mu_;
};
// ExternalReference implementation that remembers the `PjRtCApiClient` and
// `PjRtCApiBuffer` it was created with. Both are raw pointers.
class PjRtCApiExternalReference : public PjRtBuffer::ExternalReference {
 public:
  // `data_ptr_` is not declared in this class (presumably it is inherited
  // from ExternalReference), so it is assigned in the constructor body rather
  // than in the member-initializer list.
  PjRtCApiExternalReference(PjRtCApiClient* client, PjRtCApiBuffer* buffer,
                            void* data_ptr)
      : client_(client), buffer_(buffer) {
    this->data_ptr_ = data_ptr;
  }

  ~PjRtCApiExternalReference() override;

  Status WaitUntilBufferReadyOnStream(std::intptr_t stream) override;

 private:
  PjRtCApiClient* client_;
  PjRtCApiBuffer* buffer_;
};
// PjRtExecutable implementation that owns a `PJRT_Executable` handle (through
// `executable_`) and holds the `PJRT_Api` pointer it was constructed with.
class PjRtCApiExecutable : public PjRtExecutable {
 public:
  PjRtCApiExecutable(const PJRT_Api* c_api, PJRT_Executable* executable);

  absl::string_view name() const override;

  int num_replicas() const override;

  int num_partitions() const override;

  int64_t SizeOfGeneratedCodeInBytes() const override;

  StatusOr<absl::flat_hash_map<std::string, PjRtValueType>> GetCostAnalysis()
      const override;

  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const override;

  // Forwards to the `pjrt::GetCompiledMemoryStats` helper with this object's
  // API table and executable handle.
  StatusOr<CompiledMemoryStats> GetCompiledMemoryStats() const override {
    return pjrt::GetCompiledMemoryStats(c_api_, executable_.get());
  }

  // Not supported: terminates the process via LOG(FATAL); the message points
  // callers at the element-type and dimension accessors instead.
  StatusOr<std::vector<Shape>> GetOutputShapes() const override {
    LOG(FATAL) << "PjRtExecutable::GetOutputShapes() not implemented in PJRT C "
                  "API. Please use PjRtExecutable::GetOutputElementTypes() or "
                  "PjRtExecutable::GetOutputDimensions().";
  }

  StatusOr<std::vector<std::vector<PrimitiveType>>> GetOutputElementTypes()
      const override;

  StatusOr<std::vector<std::vector<DimensionVector>>> GetOutputDimensions()
      const override;

  StatusOr<std::vector<std::vector<absl::string_view>>> GetOutputMemoryKinds()
      const override;

  // Accessors for the underlying C API objects.
  const PJRT_Api* pjrt_c_api() const { return c_api_; }

  PJRT_Executable* c_executable() const { return executable_.get(); }

  StatusOr<std::string> SerializeExecutable() const override;

  StatusOr<std::string> FingerprintExecutable() const override;

 private:
  const PJRT_Api* c_api_;
  std::unique_ptr<PJRT_Executable, ::pjrt::PJRT_ExecutableDeleter> executable_;
};
// PjRtLoadedExecutable implementation that owns a `PJRT_LoadedExecutable`
// handle plus a `PjRtCApiExecutable` (`executable_`). Most metadata queries
// are forwarded to `executable_`.
class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable {
 public:
  PjRtCApiLoadedExecutable(PjRtCApiClient* client,
                           PJRT_LoadedExecutable* executable);

  PjRtClient* client() const override { return client_; }

  // The following accessors forward to `executable_`.
  absl::string_view name() const override { return executable_->name(); }
  int num_replicas() const override { return executable_->num_replicas(); }
  int num_partitions() const override { return executable_->num_partitions(); }

  int64_t SizeOfGeneratedCodeInBytes() const override {
    return executable_->SizeOfGeneratedCodeInBytes();
  }

  StatusOr<absl::flat_hash_map<std::string, PjRtValueType>> GetCostAnalysis()
      const override {
    return executable_->GetCostAnalysis();
  }

  // Not supported: terminates the process. Uses LOG(FATAL), matching
  // GetOutputShapes() below, so that the function's never-returning nature is
  // explicit.
  const DeviceAssignment& device_assignment() const override {
    LOG(FATAL) << "PJRT C API does not support device_assignment";
  }

  // Not supported: terminates the process via LOG(FATAL).
  absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const override {
    LOG(FATAL)
        << "PJRT C API does not support addressable_device_logical_ids";
  }

  // Returns a view over `addressable_devices_`.
  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const override {
    return executable_->GetHloModules();
  }

  StatusOr<CompiledMemoryStats> GetCompiledMemoryStats() const override {
    return executable_->GetCompiledMemoryStats();
  }

  // Not supported: terminates the process via LOG(FATAL); the message points
  // callers at the element-type and dimension accessors instead.
  StatusOr<std::vector<Shape>> GetOutputShapes() const override {
    LOG(FATAL)
        << "PjRtLoadedExecutable::GetOutputShapes() not implemented in PJRT C "
           "API. Please use PjRtLoadedExecutable::GetOutputElementTypes() or "
           "PjRtLoadedExecutable::GetOutputDimensions().";
  }

  StatusOr<std::vector<std::vector<PrimitiveType>>> GetOutputElementTypes()
      const override {
    return executable_->GetOutputElementTypes();
  }

  StatusOr<std::vector<std::vector<DimensionVector>>> GetOutputDimensions()
      const override {
    return executable_->GetOutputDimensions();
  }

  StatusOr<std::vector<std::vector<absl::string_view>>> GetOutputMemoryKinds()
      const override {
    return executable_->GetOutputMemoryKinds();
  }

  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
      const ExecuteOptions& options,
      std::optional<std::vector<PjRtFuture<>>>& returned_futures) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options,
      std::optional<PjRtFuture<>>& returned_future, bool fill_future) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options,
      std::optional<PjRtFuture<>>& returned_future, bool fill_future) override;

  void Delete() override;
  bool IsDeleted() override;

  StatusOr<std::string> SerializeExecutable() const override {
    return executable_->SerializeExecutable();
  }

  // Returns the C API table of the owning client.
  const PJRT_Api* pjrt_c_api() const { return client_->pjrt_c_api(); }

  // Returns the `PJRT_Executable` handle held by `executable_`.
  PJRT_Executable* c_executable() const { return executable_->c_executable(); }

  // Returns the raw `PJRT_LoadedExecutable` handle; ownership stays with
  // `loaded_executable_`.
  PJRT_LoadedExecutable* c_loaded_executable() const {
    return loaded_executable_.get();
  }

  // True if the `returned_futures` output parameter is supported in the
  // Execute*() methods.
  bool IsReturnedFutureSupported() const override { return true; }

  // std::function version of PJRT_SendCallback
  using SendCallbackFunction = std::function<PJRT_Error*(
      PJRT_Chunk*, PJRT_CallbackError*, size_t, bool)>;
  // std::function version of PJRT_RecvCallback
  using RecvCallbackFunction = std::function<void(PJRT_CopyToDeviceStream*)>;

  // Override to call FingerprintExecutable through the wrapped
  // PjRtCApiExecutable.
  StatusOr<std::string> FingerprintExecutable() const override;

 private:
  // Groups data needed to support send/recv execution callbacks.
  struct SendRecvCallbackData {
    std::vector<std::vector<PJRT_SendCallbackInfo>> c_send_callbacks;
    std::vector<PJRT_SendCallbackInfo*> c_send_callback_lists;
    std::vector<std::vector<PJRT_RecvCallbackInfo>> c_recv_callbacks;
    std::vector<PJRT_RecvCallbackInfo*> c_recv_callback_lists;
    std::vector<SendCallbackFunction> send_callback_functions;
    std::vector<RecvCallbackFunction> recv_callback_functions;
  };

  // Gets common Execute_Args between Execute, ExecuteSharded and
  // ExecutePortable. device_complete_events in the return is set if the input
  // device_complete_events has value.
  // The reference parameters are caller-provided storage.
  // NOTE(review): presumably the returned args point into that storage, so it
  // must outlive them -- confirm in the .cc file.
  absl::StatusOr<PJRT_LoadedExecutable_Execute_Args> GetCommonExecuteArgs(
      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
      const ExecuteOptions& options, PJRT_ExecuteOptions& c_options,
      std::vector<std::vector<PJRT_Buffer*>>& c_argument_lists_storage,
      std::vector<PJRT_Buffer**>& c_arguments,
      std::vector<std::vector<PJRT_Buffer*>>& c_output_lists_storage,
      std::vector<PJRT_Buffer**>& c_output_lists,
      std::optional<std::vector<PJRT_Event*>>& device_complete_events,
      SendRecvCallbackData& send_recv_callback_data,
      std::vector<int64_t>& non_donatable_input_indices_storage);

  // Takes the same arguments as ExecuteSharded and ExecutePortable.
  // NOTE(review): presumably their shared implementation -- confirm in the
  // .cc file.
  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteWithSingleDevice(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options,
      std::optional<PjRtFuture<>>& returned_future, bool fill_future);

  PjRtCApiClient* client_;
  std::unique_ptr<PJRT_LoadedExecutable, ::pjrt::PJRT_LoadedExecutableDeleter>
      loaded_executable_;
  std::unique_ptr<PjRtCApiExecutable> executable_;
  std::vector<PjRtDevice*> addressable_devices_;

  void InitDevices();
};
// CopyToDeviceStream implementation that holds a `PJRT_CopyToDeviceStream`
// handle and the `PJRT_Api` pointer passed at construction. The constructor,
// destructor and AddChunk are defined out of line.
class CApiCopyToDeviceStream : public CopyToDeviceStream {
public:
CApiCopyToDeviceStream(PJRT_CopyToDeviceStream* c_stream,
const PJRT_Api* c_api);
// NOTE(review): whether the destructor releases `c_stream_` is not visible
// from this header -- confirm in the .cc file.
~CApiCopyToDeviceStream() override;
// Adds `chunk` to the stream and returns a future for the operation.
PjRtFuture<> AddChunk(PjRtChunk chunk) override;
private:
PJRT_CopyToDeviceStream* c_stream_;
const PJRT_Api* c_api_;
};
// Returns a PjRtClient for `device_type`, or an error status.
// `create_options` defaults to an empty map and `kv_store` to nullptr.
// NOTE(review): presumably `device_type` selects the PJRT plugin to use --
// confirm against the definition.
StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient(
absl::string_view device_type,
const absl::flat_hash_map<std::string, PjRtValueType>& create_options = {},
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr);
// Returns a topology description named `topology_name` using the given
// `c_api` table, or an error status. `create_options` has no default in this
// overload.
absl::StatusOr<std::unique_ptr<PjRtTopologyDescription>> GetCApiTopology(
const PJRT_Api* c_api, absl::string_view topology_name,
const absl::flat_hash_map<std::string, PjRtValueType>& create_options);
// A variant that takes `device_type` as an input, used for plugins that are
// not registered in the standard way (xla_bridge.register_plugin).
// TODO(b/322357665): Delete this method after TPU plugin changes to use the
// standard registration.
StatusOr<std::unique_ptr<PjRtTopologyDescription>> GetCApiTopology(
absl::string_view device_type, absl::string_view topology_name,
const absl::flat_hash_map<std::string, PjRtValueType>& create_options = {});
} // namespace xla
#endif // XLA_PJRT_PJRT_C_API_CLIENT_H_