Skip to content

Commit

Permalink
update gpt_with_pir.py
Browse files Browse the repository at this point in the history
  • Loading branch information
gongshaotian committed Nov 10, 2023
1 parent ff56f25 commit e9fb452
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 15 deletions.
17 changes: 3 additions & 14 deletions test/auto_parallel/gpt_with_pir.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,10 @@ def test_dp_with_fused_linear(self):
out_dp_ir = engine_dp_ir.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
# TODO(zhiqiu): fix accuracy problem and use array_equal to check it
np.testing.assert_allclose(
out_dp_prog.history["loss"][0],
out_dp_ir.history["loss"][0],
rtol=1e-5,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__,
out_dp_prog.history["loss"][0],
out_dp_ir.history["loss"][0],
out_dp_prog.history["loss"][0] - out_dp_ir.history["loss"][0],
),

self.check_results(
out_dp_prog.history["loss"][0], out_dp_ir.history["loss"][0]
)
# self.check_results(
# out_dp_prog.history["loss"][0], out_dp_ir.history["loss"][0]
# )

def test_mp(self):
self.enable_pir(False)
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,5 @@ TEST(DrrTest, FusedLinear) {
pm.EnableIRPrinting();

CHECK_EQ(pm.Run(&program), true);
EXPECT_EQ(program.block()->size(), 22u);
EXPECT_EQ(program.block()->size(), 23u);
}

0 comments on commit e9fb452

Please sign in to comment.