Skip to content

Commit

Permalink
Merge branch 'wangwang/support_tp_vllm' into xuejun/tmp-1017-v2
Browse files: browse the repository at this point in the history
  • Loading branch information
zhaixuejun1993 authored Oct 22, 2024
2 parents 25de131 + 8402627 commit 3886ba2
Showing 1 changed file with 5 additions and 7 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -208,9 +208,7 @@ TEST(TransformationTestsF1, FullyConnectedSplitInput16) {

// -------- Loading a model to the device --------
ov::Core core;
// ov::CompiledModel compiled_model = core.compile_model(model, "GPU");
ov::CompiledModel compiled_model = core.compile_model(model, "GPU", ov::device::priorities("GPU.0,GPU.1,GPU.2,GPU.3"));
// ov::CompiledModel compiled_model = core.compile_model(model, "GPU", ov::device::priorities("GPU.0,GPU.1,GPU.2,GPU.3"));
ov::CompiledModel compiled_model = core.compile_model(model, "GPU", ov::device::priorities("GPU.0,GPU.1"));

// -------- Create an infer request --------
ov::InferRequest infer_request = compiled_model.create_infer_request();
Expand Down Expand Up @@ -252,12 +250,12 @@ TEST(TransformationTestsF1, FullyConnectedSplitInput1024) {
{
// -------- Construct model
unsigned long m = 1024, k = 2048, n = 13696;
auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{m, n});
auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{m, k});
std::vector<float> input_data(m * k, 1);
std::vector<float> weights(k * n, 2);
std::vector<float> weights(n * k, 2);
std::vector<float> result(m * n, 2);

auto input2 = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{k, n}, weights.data());
auto input2 = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{n, k}, weights.data());

std::cout << "input_shape: " << m << " * " << k << std::endl;
std::cout << "weight_shape: " << k << " * " << n << std::endl;
Expand All @@ -273,7 +271,7 @@ TEST(TransformationTestsF1, FullyConnectedSplitInput1024) {

// -------- Loading a model to the device --------
ov::Core core;
ov::CompiledModel compiled_model = core.compile_model(model, "GPU");
ov::CompiledModel compiled_model = core.compile_model(model, "GPU", ov::device::priorities("GPU.0,GPU.1"));

// -------- Create an infer request --------
ov::InferRequest infer_request = compiled_model.create_infer_request();
Expand Down

0 comments on commit 3886ba2

Please sign in to comment.