Skip to content

Commit

Permalink
8328404: RISC-V: Fix potential crash in C2_MacroAssembler::arrays_equals
Browse files Browse the repository at this point in the history
  • Loading branch information
zifeihan committed Mar 19, 2024
1 parent 9059727 commit 8824e1c
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 117 deletions.
177 changes: 87 additions & 90 deletions src/hotspot/cpu/riscv/c2_MacroAssembler_riscv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1536,93 +1536,105 @@ void C2_MacroAssembler::string_compare(Register str1, Register str2,
BLOCK_COMMENT("} string_compare");
}

void C2_MacroAssembler::arrays_equals(Register a1, Register a2, Register tmp3,
Register tmp4, Register tmp5, Register tmp6, Register result,
Register cnt1, int elem_size) {
Label DONE, SAME, NEXT_DWORD, SHORT, TAIL, TAIL2, IS_TMP5_ZR;
Register tmp1 = t0;
Register tmp2 = t1;
Register cnt2 = tmp2; // cnt2 only used in array length compare
Register elem_per_word = tmp6;
void C2_MacroAssembler::arrays_equals(Register a1, Register a2,
Register tmp1, Register tmp2, Register tmp3,
Register result, int elem_size) {
assert(elem_size == 1 || elem_size == 2, "must be char or byte");
assert_different_registers(a1, a2, result, tmp1, tmp2, tmp3, t0);

int elem_per_word = wordSize/elem_size;
int log_elem_size = exact_log2(elem_size);
int length_offset = arrayOopDesc::length_offset_in_bytes();
int base_offset = arrayOopDesc::base_offset_in_bytes(elem_size == 2 ? T_CHAR : T_BYTE);

assert(elem_size == 1 || elem_size == 2, "must be char or byte");
assert_different_registers(a1, a2, result, cnt1, t0, t1, tmp3, tmp4, tmp5, tmp6);
mv(elem_per_word, wordSize / elem_size);
Register cnt1 = tmp3;
Register cnt2 = tmp1; // cnt2 only used in array length compare
Label DONE, SAME, NEXT_WORD, SHORT, TAIL03, TAIL01;

BLOCK_COMMENT("arrays_equals {");

// if (a1 == a2), return true
beq(a1, a2, SAME);

mv(result, false);
// if (a1 == nullptr || a2 == nullptr)
// return false;
beqz(a1, DONE);
beqz(a2, DONE);

// if (a1.length != a2.length)
// return false;
lwu(cnt1, Address(a1, length_offset));
lwu(cnt2, Address(a2, length_offset));
bne(cnt2, cnt1, DONE);
beqz(cnt1, SAME);

slli(tmp5, cnt1, 3 + log_elem_size);
sub(tmp5, zr, tmp5);
add(a1, a1, base_offset);
add(a2, a2, base_offset);
ld(tmp3, Address(a1, 0));
ld(tmp4, Address(a2, 0));
ble(cnt1, elem_per_word, SHORT); // short or same

// Main 16 byte comparison loop with 2 exits
bind(NEXT_DWORD); {
ld(tmp1, Address(a1, wordSize));
ld(tmp2, Address(a2, wordSize));
sub(cnt1, cnt1, 2 * wordSize / elem_size);
blez(cnt1, TAIL);
bne(tmp3, tmp4, DONE);
ld(tmp3, Address(a1, 2 * wordSize));
ld(tmp4, Address(a2, 2 * wordSize));
add(a1, a1, 2 * wordSize);
add(a2, a2, 2 * wordSize);
ble(cnt1, elem_per_word, TAIL2);
} beq(tmp1, tmp2, NEXT_DWORD);
j(DONE);
bne(cnt1, cnt2, DONE);

bind(TAIL);
xorr(tmp4, tmp3, tmp4);
xorr(tmp2, tmp1, tmp2);
sll(tmp2, tmp2, tmp5);
orr(tmp5, tmp4, tmp2);
j(IS_TMP5_ZR);
la(a1, Address(a1, base_offset));
la(a2, Address(a2, base_offset));
// Check for short strings, i.e. smaller than wordSize.
addi(cnt1, cnt1, -elem_per_word);
bltz(cnt1, SHORT);

// Main 8 byte comparison loop.
bind(NEXT_WORD); {
ld(tmp1, Address(a1));
ld(tmp2, Address(a2));
addi(cnt1, cnt1, -elem_per_word);
addi(a1, a1, wordSize);
addi(a2, a2, wordSize);
bne(tmp1, tmp2, DONE);
} bgez(cnt1, NEXT_WORD);

bind(TAIL2);
bne(tmp1, tmp2, DONE);
addi(tmp1, cnt1, elem_per_word);
beqz(tmp1, SAME);

bind(SHORT);
xorr(tmp4, tmp3, tmp4);
sll(tmp5, tmp4, tmp5);
test_bit(tmp1, cnt1, 2 - log_elem_size);
beqz(tmp1, TAIL03); // 0-7 bytes left.
{
lwu(tmp1, Address(a1));
lwu(tmp2, Address(a2));
addi(a1, a1, 4);
addi(a2, a2, 4);
bne(tmp1, tmp2, DONE);
}

bind(IS_TMP5_ZR);
bnez(tmp5, DONE);
bind(TAIL03);
test_bit(tmp1, cnt1, 1 - log_elem_size);
beqz(tmp1, TAIL01); // 0-3 bytes left.
{
lhu(tmp1, Address(a1));
lhu(tmp2, Address(a2));
addi(a1, a1, 2);
addi(a2, a2, 2);
bne(tmp1, tmp2, DONE);
}

bind(TAIL01);
if (elem_size == 1) { // Only needed when comparing byte arrays.
test_bit(tmp1, cnt1, 0);
beqz(tmp1, SAME); // 0-1 bytes left.
{
lbu(tmp1, Address(a1));
lbu(tmp2, Address(a2));
bne(tmp1, tmp2, DONE);
}
}

bind(SAME);
mv(result, true);
// That's it.
bind(DONE);

BLOCK_COMMENT("} array_equals");
BLOCK_COMMENT("} arrays_equals");
}

// Compare Strings

// For Strings we're passed the address of the first characters in a1
// and a2 and the length in cnt1.
// There are two implementations. For arrays >= 8 bytes, all
// comparisons (for hw supporting unaligned access: including the final one,
// which may overlap) are performed 8 bytes at a time.
// For strings < 8 bytes (and for tails of long strings when
// AvoidUnalignedAccesses is true), we compare a
// halfword, then a short, and then a byte.
// For Strings we're passed the address of the first characters in a1 and a2
// and the length in cnt1. There are two implementations.
// For arrays >= 8 bytes, all comparisons (except for the tail) are performed
// 8 bytes at a time. For the tail, we compare a halfword, then a short, and then a byte.
// For strings < 8 bytes, we compare a halfword, then a short, and then a byte.

void C2_MacroAssembler::string_equals(Register a1, Register a2,
Register result, Register cnt1)
Expand All @@ -1635,39 +1647,24 @@ void C2_MacroAssembler::string_equals(Register a1, Register a2,

BLOCK_COMMENT("string_equals {");

beqz(cnt1, SAME);
mv(result, false);

// Check for short strings, i.e. smaller than wordSize.
sub(cnt1, cnt1, wordSize);
addi(cnt1, cnt1, -wordSize);
bltz(cnt1, SHORT);

// Main 8 byte comparison loop.
bind(NEXT_WORD); {
ld(tmp1, Address(a1, 0));
add(a1, a1, wordSize);
ld(tmp2, Address(a2, 0));
add(a2, a2, wordSize);
sub(cnt1, cnt1, wordSize);
ld(tmp1, Address(a1));
ld(tmp2, Address(a2));
addi(cnt1, cnt1, -wordSize);
addi(a1, a1, wordSize);
addi(a2, a2, wordSize);
bne(tmp1, tmp2, DONE);
} bgez(cnt1, NEXT_WORD);

if (!AvoidUnalignedAccesses) {
// Last longword. In the case where length == 4 we compare the
// same longword twice, but that's still faster than another
// conditional branch.
// cnt1 could be 0, -1, -2, -3, -4 for chars; -4 only happens when
// length == 4.
add(tmp1, a1, cnt1);
ld(tmp1, Address(tmp1, 0));
add(tmp2, a2, cnt1);
ld(tmp2, Address(tmp2, 0));
bne(tmp1, tmp2, DONE);
j(SAME);
} else {
add(tmp1, cnt1, wordSize);
beqz(tmp1, SAME);
}
addi(tmp1, cnt1, wordSize);
beqz(tmp1, SAME);

bind(SHORT);
Label TAIL03, TAIL01;
Expand All @@ -1676,10 +1673,10 @@ void C2_MacroAssembler::string_equals(Register a1, Register a2,
test_bit(tmp1, cnt1, 2);
beqz(tmp1, TAIL03);
{
lwu(tmp1, Address(a1, 0));
add(a1, a1, 4);
lwu(tmp2, Address(a2, 0));
add(a2, a2, 4);
lwu(tmp1, Address(a1));
lwu(tmp2, Address(a2));
addi(a1, a1, 4);
addi(a2, a2, 4);
bne(tmp1, tmp2, DONE);
}

Expand All @@ -1688,10 +1685,10 @@ void C2_MacroAssembler::string_equals(Register a1, Register a2,
test_bit(tmp1, cnt1, 1);
beqz(tmp1, TAIL01);
{
lhu(tmp1, Address(a1, 0));
add(a1, a1, 2);
lhu(tmp2, Address(a2, 0));
add(a2, a2, 2);
lhu(tmp1, Address(a1));
lhu(tmp2, Address(a2));
addi(a1, a1, 2);
addi(a2, a2, 2);
bne(tmp1, tmp2, DONE);
}

Expand All @@ -1700,8 +1697,8 @@ void C2_MacroAssembler::string_equals(Register a1, Register a2,
test_bit(tmp1, cnt1, 0);
beqz(tmp1, SAME);
{
lbu(tmp1, Address(a1, 0));
lbu(tmp2, Address(a2, 0));
lbu(tmp1, Address(a1));
lbu(tmp2, Address(a2));
bne(tmp1, tmp2, DONE);
}

Expand Down
7 changes: 3 additions & 4 deletions src/hotspot/cpu/riscv/c2_MacroAssembler_riscv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,15 @@
int needle_con_cnt, Register result, int ae);

void arrays_equals(Register r1, Register r2,
Register tmp3, Register tmp4,
Register tmp5, Register tmp6,
Register result, Register cnt1,
int elem_size);
Register tmp1, Register tmp2, Register tmp3,
Register result, int elem_size);

void arrays_hashcode(Register ary, Register cnt, Register result,
Register tmp1, Register tmp2,
Register tmp3, Register tmp4,
Register tmp5, Register tmp6,
BasicType eltype);

// helper function for arrays_hashcode
int arrays_hashcode_elsize(BasicType eltype);
void arrays_hashcode_elload(Register dst, Address src, BasicType eltype);
Expand Down
33 changes: 10 additions & 23 deletions src/hotspot/cpu/riscv/riscv.ad
Original file line number Diff line number Diff line change
Expand Up @@ -3286,17 +3286,6 @@ operand iRegP_R15()
interface(REG_INTER);
%}

operand iRegP_R16()
%{
constraint(ALLOC_IN_RC(r16_reg));
match(RegP);
// match(iRegP);
match(iRegPNoSp);
op_cost(0);
format %{ %}
interface(REG_INTER);
%}

// Pointer 64 bit Register R28 only
operand iRegP_R28()
%{
Expand Down Expand Up @@ -10336,35 +10325,33 @@ instruct string_equalsL(iRegP_R11 str1, iRegP_R13 str2, iRegI_R14 cnt,
%}

instruct array_equalsB(iRegP_R11 ary1, iRegP_R12 ary2, iRegI_R10 result,
iRegP_R13 tmp1, iRegP_R14 tmp2, iRegP_R15 tmp3,
iRegP_R16 tmp4, iRegP_R28 tmp5, rFlagsReg cr)
iRegP_R13 tmp1, iRegP_R14 tmp2, iRegP_R15 tmp3)
%{
predicate(!UseRVV && ((AryEqNode*)n)->encoding() == StrIntrinsicNode::LL);
match(Set result (AryEq ary1 ary2));
effect(USE_KILL ary1, USE_KILL ary2, TEMP tmp1, TEMP tmp2, TEMP tmp3, TEMP tmp4, KILL tmp5, KILL cr);
effect(USE_KILL ary1, USE_KILL ary2, TEMP tmp1, TEMP tmp2, TEMP tmp3);

format %{ "Array Equals $ary1, ary2 -> $result\t#@array_equalsB // KILL $tmp5" %}
format %{ "Array Equals $ary1, $ary2 -> $result\t#@array_equalsB // KILL all" %}
ins_encode %{
__ arrays_equals($ary1$$Register, $ary2$$Register,
$tmp1$$Register, $tmp2$$Register, $tmp3$$Register, $tmp4$$Register,
$result$$Register, $tmp5$$Register, 1);
$tmp1$$Register, $tmp2$$Register, $tmp3$$Register,
$result$$Register, 1);
%}
ins_pipe(pipe_class_memory);
%}

instruct array_equalsC(iRegP_R11 ary1, iRegP_R12 ary2, iRegI_R10 result,
iRegP_R13 tmp1, iRegP_R14 tmp2, iRegP_R15 tmp3,
iRegP_R16 tmp4, iRegP_R28 tmp5, rFlagsReg cr)
iRegP_R13 tmp1, iRegP_R14 tmp2, iRegP_R15 tmp3)
%{
predicate(!UseRVV && ((AryEqNode*)n)->encoding() == StrIntrinsicNode::UU);
match(Set result (AryEq ary1 ary2));
effect(USE_KILL ary1, USE_KILL ary2, TEMP tmp1, TEMP tmp2, TEMP tmp3, TEMP tmp4, KILL tmp5, KILL cr);
effect(USE_KILL ary1, USE_KILL ary2, TEMP tmp1, TEMP tmp2, TEMP tmp3);

format %{ "Array Equals $ary1, ary2 -> $result\t#@array_equalsC // KILL $tmp5" %}
format %{ "Array Equals $ary1, $ary2 -> $result\t#@array_equalsC // KILL all" %}
ins_encode %{
__ arrays_equals($ary1$$Register, $ary2$$Register,
$tmp1$$Register, $tmp2$$Register, $tmp3$$Register, $tmp4$$Register,
$result$$Register, $tmp5$$Register, 2);
$tmp1$$Register, $tmp2$$Register, $tmp3$$Register,
$result$$Register, 2);
%}
ins_pipe(pipe_class_memory);
%}
Expand Down

0 comments on commit 8824e1c

Please sign in to comment.