Fix `or` operator usage in mode checks
shentianxiao committed May 24, 2022
1 parent 8fc89b0 commit 1d1f3c9
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions cold_decoding.py
@@ -264,7 +264,7 @@ def decode(model, tokenizer, device, x="", z="", constraints=None, args=None, mo
    rl_nll_loss = soft_nll(
        top_k_filter_3d(y_logits_rev_t / args.output_lgt_temp, args.rl_topk),
        y_logits_rev[:, 1:] / args.input_lgt_temp)
- elif "abductive" or "lexical" in args.mode:
+ elif "abductive" in args.mode or "lexical" in args.mode:
    yz_logits_rev = torch.flip(torch.cat([y_logits_, z_onehot], dim=1), [1])
    yz_logits_rev_t = soft_backward(model_back, yz_logits_rev / 0.00001)
    yz_logits_rev_rev_t = torch.flip(yz_logits_rev_t, [1])
@@ -289,7 +289,7 @@ def decode(model, tokenizer, device, x="", z="", constraints=None, args=None, mo
        ngram_list=list(range(2, args.counterfactual_max_ngram + 1))
    )

- if "abductive" or "lexical" in args.mode:
+ if "abductive" in args.mode or "lexical" in args.mode:
    soft_forward_y_ = (y_logits_.detach() / 0.3 - y_logits_).detach() + y_logits_
    xyz_logits, xy_length = soft_forward_xyz(model, soft_forward_x, soft_forward_y_, z_onehot)

@@ -328,7 +328,7 @@ def decode(model, tokenizer, device, x="", z="", constraints=None, args=None, mo
    text, _, _ = decode_with_model_topk(
        model, y_logits_, args.topk, soft_forward_x, x_model_past, tokenizer, extra_mask=z_mask)
    for bi in range(args.batch_size):
-       if "abductive" or "lexical" in args.mode:
+       if "abductive" in args.mode or "lexical" in args.mode:
            print(
                "%d, loss: %.4f, lr_nll_loss: %.4f, rl_nll_loss: %.4f, c_loss_2: %.4f, lr: %.4f, |%s|" % (
                    iter + 1, loss.item(), lr_nll_loss[bi].item(), rl_nll_loss[bi].item(),
@@ -338,7 +338,7 @@ def decode(model, tokenizer, device, x="", z="", constraints=None, args=None, mo
            print("%d, loss: %.4f, lr_nll_loss: %.4f, c_loss: %.4f, lr: %.4f, |%s|" % (
                iter + 1, loss.item(), lr_nll_loss[bi].item(), c_loss[bi].item(), last_lr, text[bi]))

-   if "abductive" or "lexical" in args.mode:
+   if "abductive" in args.mode or "lexical" in args.mode:
        pass

    print()
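For context, a minimal sketch of the bug this commit fixes: in Python, `in` binds tighter than `or`, so "abductive" or "lexical" in args.mode parses as "abductive" or ("lexical" in args.mode). The non-empty string literal "abductive" is always truthy, so each of these branches fired regardless of args.mode. The snippet below uses a hypothetical mode value (not taken from the repository) to illustrate the difference:

mode = "counterfactual"  # hypothetical value containing neither "abductive" nor "lexical"

# Old condition: evaluates to the truthy string "abductive" for any mode.
old_condition = "abductive" or "lexical" in mode

# Fixed condition: two explicit membership tests.
new_condition = "abductive" in mode or "lexical" in mode

print(bool(old_condition))  # True, regardless of mode
print(new_condition)        # False for this mode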
