Skip to content

Commit

Permalink
Handles gpu/cpu transfer in QPFunction's backward + replaces torch.Te…
Browse files Browse the repository at this point in the history
…nsor by torch.empty or torch.tensor
  • Loading branch information
oumayb authored and fabinsch committed Jan 19, 2024
1 parent 671e9b6 commit 2cd1521
Showing 1 changed file with 48 additions and 47 deletions.
95 changes: 48 additions & 47 deletions bindings/python/proxsuite/torch/qplayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,41 +50,41 @@ def QPFunction(
Solve the QP problem.
Args:
Q (torch.Tensor): Batch of quadratic cost matrices of size (nBatch, n, n) or (n, n).
p (torch.Tensor): Batch of linear cost vectors of size (nBatch, n) or (n).
A (torch.Tensor, optional): Batch of eq. constraint matrices of size (nBatch, p, n) or (p, n).
b (torch.Tensor, optional): Batch of eq. constraint vectors of size (nBatch, p) or (p).
G (torch.Tensor): Batch of ineq. constraint matrices of size (nBatch, m, n) or (m, n).
l (torch.Tensor): Batch of ineq. lower bound vectors of size (nBatch, m) or (m).
u (torch.Tensor): Batch of ineq. upper bound vectors of size (nBatch, m) or (m).
Q (torch.tensor): Batch of quadratic cost matrices of size (nBatch, n, n) or (n, n).
p (torch.tensor): Batch of linear cost vectors of size (nBatch, n) or (n).
A (torch.tensor, optional): Batch of eq. constraint matrices of size (nBatch, p, n) or (p, n).
b (torch.tensor, optional): Batch of eq. constraint vectors of size (nBatch, p) or (p).
G (torch.tensor): Batch of ineq. constraint matrices of size (nBatch, m, n) or (m, n).
l (torch.tensor): Batch of ineq. lower bound vectors of size (nBatch, m) or (m).
u (torch.tensor): Batch of ineq. upper bound vectors of size (nBatch, m) or (m).
Returns:
zhats (torch.Tensor): Batch of optimal primal solutions of size (nBatch, n).
lams (torch.Tensor): Batch of dual variables for eq. constraint of size (nBatch, m).
nus (torch.Tensor): Batch of dual variables for ineq. constraints of size (nBatch, p).
zhats (torch.tensor): Batch of optimal primal solutions of size (nBatch, n).
lams (torch.tensor): Batch of dual variables for eq. constraint of size (nBatch, m).
nus (torch.tensor): Batch of dual variables for ineq. constraints of size (nBatch, p).
Only for infeasible case:
s_e (torch.Tensor): Batch of slack variables for eq. constraints of size (nBatch, m).
s_i (torch.Tensor): Batch of slack variables for ineq. constraints of size (nBatch, p).
s_e (torch.tensor): Batch of slack variables for eq. constraints of size (nBatch, m).
s_i (torch.tensor): Batch of slack variables for ineq. constraints of size (nBatch, p).
Backward:
Compute the gradients of the QP problem wrt its parameters.
Args:
dl_dzhat (torch.Tensor): Batch of gradients of size (nBatch, n).
dl_dlams (torch.Tensor, optional): Batch of gradients of size (nBatch, p).
dl_dnus (torch.Tensor, optional): Batch of gradients of size (nBatch, m).
dl_dzhat (torch.tensor): Batch of gradients of size (nBatch, n).
dl_dlams (torch.tensor, optional): Batch of gradients of size (nBatch, p).
dl_dnus (torch.tensor, optional): Batch of gradients of size (nBatch, m).
Only for infeasible case:
dl_ds_e (torch.Tensor, optional): Batch of gradients of size (nBatch, m).
dl_ds_i (torch.Tensor, optional): Batch of gradients of size (nBatch, m).
dl_ds_e (torch.tensor, optional): Batch of gradients of size (nBatch, m).
dl_ds_i (torch.tensor, optional): Batch of gradients of size (nBatch, m).
Returns:
dQs (torch.Tensor): Batch of gradients of size (nBatch, n, n).
dps (torch.Tensor): Batch of gradients of size (nBatch, n).
dAs (torch.Tensor): Batch of gradients of size (nBatch, p, n).
dbs (torch.Tensor): Batch of gradients of size (nBatch, p).
dGs (torch.Tensor): Batch of gradients of size (nBatch, m, n).
dls (torch.Tensor): Batch of gradients of size (nBatch, m).
dus (torch.Tensor): Batch of gradients of size (nBatch, m).
dQs (torch.tensor): Batch of gradients of size (nBatch, n, n).
dps (torch.tensor): Batch of gradients of size (nBatch, n).
dAs (torch.tensor): Batch of gradients of size (nBatch, p, n).
dbs (torch.tensor): Batch of gradients of size (nBatch, p).
dGs (torch.tensor): Batch of gradients of size (nBatch, m, n).
dls (torch.tensor): Batch of gradients of size (nBatch, m).
dus (torch.tensor): Batch of gradients of size (nBatch, m).
"""
global proxqp_parallel
proxqp_parallel = omp_parallel
Expand Down Expand Up @@ -114,9 +114,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
if ctx.cpu is not None:
ctx.cpu = max(1, int(ctx.cpu / 2))

zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
lams = torch.Tensor(nBatch, ctx.neq).type_as(Q)
nus = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
zhats = torch.empty((nBatch, ctx.nz)).type_as(Q)
lams = torch.empty((nBatch, ctx.neq)).type_as(Q)
nus = torch.empty((nBatch, ctx.nineq)).type_as(Q)

for i in range(nBatch):
qp = ctx.vector_of_qps.init_qp_in_place(ctx.nz, ctx.neq, ctx.nineq)
Expand Down Expand Up @@ -163,22 +163,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
ctx.vector_of_qps.get(i).solve()

for i in range(nBatch):
zhats[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.x)
lams[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.y)
nus[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.z)
zhats[i] = torch.tensor(ctx.vector_of_qps.get(i).results.x)
lams[i] = torch.tensor(ctx.vector_of_qps.get(i).results.y)
nus[i] = torch.tensor(ctx.vector_of_qps.get(i).results.z)

return zhats, lams, nus

@staticmethod
def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
device = dl_dzhat.device
nBatch, dim, neq, nineq = ctx.nBatch, ctx.nz, ctx.neq, ctx.nineq
dQs = torch.Tensor(nBatch, ctx.nz, ctx.nz)
dps = torch.Tensor(nBatch, ctx.nz)
dGs = torch.Tensor(nBatch, ctx.nineq, ctx.nz)
dus = torch.Tensor(nBatch, ctx.nineq)
dls = torch.Tensor(nBatch, ctx.nineq)
dAs = torch.Tensor(nBatch, ctx.neq, ctx.nz)
dbs = torch.Tensor(nBatch, ctx.neq)
dQs = torch.empty(nBatch, ctx.nz, ctx.nz, device=device)
dps = torch.empty(nBatch, ctx.nz, device=device)
dGs = torch.empty(nBatch, ctx.nineq, ctx.nz, device=device)
dus = torch.empty(nBatch, ctx.nineq, device=device)
dls = torch.empty(nBatch, ctx.nineq, device=device)
dAs = torch.empty(nBatch, ctx.neq, ctx.nz, device=device)
dbs = torch.empty(nBatch, ctx.neq, device=device)

ctx.cpu = os.cpu_count()
if ctx.cpu is not None:
Expand Down Expand Up @@ -211,11 +212,11 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
else:
for i in range(nBatch):
rhs = np.zeros(n_tot)
rhs[:dim] = dl_dzhat[i]
rhs[:dim] = dl_dzhat[i].cpu()
if dl_dlams != None:
rhs[dim : dim + neq] = dl_dlams[i]
rhs[dim : dim + neq] = dl_dlams[i].cpu()
if dl_dnus != None:
rhs[dim + neq :] = dl_dnus[i]
rhs[dim + neq :] = dl_dnus[i].cpu()
qpi = ctx.vector_of_qps.get(i)
proxsuite.proxqp.dense.compute_backward(
qp=qpi,
Expand All @@ -226,25 +227,25 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
)

for i in range(nBatch):
dQs[i] = torch.Tensor(
dQs[i] = torch.tensor(
ctx.vector_of_qps.get(i).model.backward_data.dL_dH
)
dps[i] = torch.Tensor(
dps[i] = torch.tensor(
ctx.vector_of_qps.get(i).model.backward_data.dL_dg
)
dGs[i] = torch.Tensor(
dGs[i] = torch.tensor(
ctx.vector_of_qps.get(i).model.backward_data.dL_dC
)
dus[i] = torch.Tensor(
dus[i] = torch.tensor(
ctx.vector_of_qps.get(i).model.backward_data.dL_du
)
dls[i] = torch.Tensor(
dls[i] = torch.tensor(
ctx.vector_of_qps.get(i).model.backward_data.dL_dl
)
dAs[i] = torch.Tensor(
dAs[i] = torch.tensor(
ctx.vector_of_qps.get(i).model.backward_data.dL_dA
)
dbs[i] = torch.Tensor(
dbs[i] = torch.tensor(
ctx.vector_of_qps.get(i).model.backward_data.dL_db
)

Expand Down

0 comments on commit 2cd1521

Please sign in to comment.