Skip to content

Commit

Permalink
refactor(repository): use native SQL for history
Browse files Browse the repository at this point in the history
Refactors PoliciesHistoryRepository to use native SQL to support
Postgresql-specific queries (like JSON lookups).

RHINENG-1191
  • Loading branch information
vkrizan committed Jul 19, 2023
1 parent 43cbdb4 commit 683774d
Showing 1 changed file with 37 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import com.redhat.cloud.policies.app.model.pager.Pager;
import io.quarkus.logging.Log;
import io.quarkus.panache.common.Sort;

import org.hibernate.query.NativeQuery;
import org.hibernate.type.LongType;
import org.hibernate.Session;

import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;
import javax.persistence.Table;
import javax.persistence.Query;
import javax.persistence.TypedQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
Expand All @@ -23,48 +26,54 @@ public class PoliciesHistoryRepository {
@Inject
Session session;

private static final String tableName = PoliciesHistoryEntry.class.getAnnotation(Table.class).name();

public long count(String orgId, UUID policyId, Pager pager) {
// Base HQL query.
String hql = "SELECT COUNT(*) FROM PoliciesHistoryEntry WHERE orgId = :orgId AND policyId = :policyId";
// Base SQL query.
String sql = String.format("SELECT COUNT(*) AS count FROM %s WHERE org_id = :orgId AND policy_id = :policyId",
tableName);

hql = addFiltersConditions(hql, pager.getFilter().getItems());
sql = addFiltersConditions(sql, pager.getFilter().getItems());

Log.tracef("HQL query ready to be executed: %s", hql);
Log.tracef("SQL query ready to be executed: %s", sql);

TypedQuery<Long> query = session.createQuery(hql, Long.class)
NativeQuery<?> query = session.createNativeQuery(sql)
.addScalar("count", LongType.INSTANCE)
.setParameter("orgId", orgId)
.setParameter("policyId", policyId.toString());

setFiltersValues(query, pager.getFilter().getItems());

return query.getSingleResult();
return (Long) query.getSingleResult();
}

public List<PoliciesHistoryEntry> find(String orgId, UUID policyId, Pager pager) {
// Base HQL query.
String hql = "FROM PoliciesHistoryEntry WHERE orgId = :orgId AND policyId = :policyId";
// Base SQL query.
String sql = String.format("SELECT * FROM %s WHERE org_id = :orgId AND policy_id = :policyId",
tableName);

hql = addFiltersConditions(hql, pager.getFilter().getItems());
sql = addFiltersConditions(sql, pager.getFilter().getItems());

// The sorts from the pager are added to the HQL query.
if (!pager.getSort().getColumns().isEmpty()) {
List<String> orderByItems = new ArrayList<>();
for (Sort.Column column : pager.getSort().getColumns()) {
getEntityFieldName(column.getName()).ifPresent(entityFieldName -> {
getSortFieldName(column.getName()).ifPresent(entityFieldName -> {
String sortDirection = getSortDirection(column.getDirection());
orderByItems.add(entityFieldName + " " + sortDirection);
});
}
if (!orderByItems.isEmpty()) {
hql += " ORDER BY " + String.join(", ", orderByItems);
sql += " ORDER BY " + String.join(", ", orderByItems);
}
} else {
hql += " ORDER BY ctime DESC, hostName ASC";
sql += " ORDER BY ctime DESC, host_name ASC";
}

Log.tracef("HQL query ready to be executed: %s", hql);
Log.tracef("SQL query ready to be executed: %s", sql);

TypedQuery<PoliciesHistoryEntry> query = session.createQuery(hql, PoliciesHistoryEntry.class)
NativeQuery<PoliciesHistoryEntry> query = session
.createNativeQuery(sql, PoliciesHistoryEntry.class)
.setParameter("orgId", orgId)
.setParameter("policyId", policyId.toString());

Expand All @@ -80,20 +89,20 @@ public List<PoliciesHistoryEntry> find(String orgId, UUID policyId, Pager pager)
return query.getResultList();
}

private static String addFiltersConditions(String hql, List<Filter.FilterItem> filterItems) {
private static String addFiltersConditions(String sql, List<Filter.FilterItem> filterItems) {
// The filters from the pager are added to the HQL query.
for (Filter.FilterItem filterItem : filterItems) {
String entityFieldName = getEntityFieldName(filterItem);
String operator = getHqlOperator(filterItem);
String fieldName = getFieldName(filterItem);
String operator = getOperator(filterItem);
// To be consistent with the previous implementation, the condition is always case-insensitive.
hql += " AND LOWER(" + entityFieldName + ")" + operator + ":" + entityFieldName;
sql += " AND LOWER(" + fieldName + ")" + operator + ":" + fieldName;
}
return hql;
return sql;
}

private static void setFiltersValues(Query query, List<Filter.FilterItem> filterItems) {
for (Filter.FilterItem filterItem : filterItems) {
String paramName = getEntityFieldName(filterItem);
String paramName = getFieldName(filterItem);
String paramValue = filterItem.value.toString().toLowerCase();
if (filterItem.operator == LIKE) {
paramValue = "%" + paramValue + "%";
Expand All @@ -106,20 +115,20 @@ private static void setFiltersValues(Query query, List<Filter.FilterItem> filter
* The following static methods may look like simple mappers, but some of them are also used to prevent SQL
* injections by whitelisting field names.
*/
private static String getEntityFieldName(Filter.FilterItem filterItem) {
private static String getFieldName(Filter.FilterItem filterItem) {
switch (filterItem.field) {
case "id":
return "hostId";
return "host_id";
case "name":
return "hostName";
return "host_name";
case "ctime":
return "ctime";
default:
throw new IllegalArgumentException("Unknown filter field: " + filterItem.field);
}
}

private static String getHqlOperator(Filter.FilterItem filterItem) {
private static String getOperator(Filter.FilterItem filterItem) {
switch (filterItem.operator) {
case EQUAL:
return " = ";
Expand All @@ -132,12 +141,12 @@ private static String getHqlOperator(Filter.FilterItem filterItem) {
}
}

private static Optional<String> getEntityFieldName(String sortColumn) {
private static Optional<String> getSortFieldName(String sortColumn) {
switch (sortColumn) {
case "id":
return Optional.of("hostId");
return Optional.of("host_id");
case "name":
return Optional.of("hostName");
return Optional.of("host_name");
case "ctime":
return Optional.of("ctime");
case "mtime":
Expand Down

0 comments on commit 683774d

Please sign in to comment.