diff --git a/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/DiscoveredSSDeep.java b/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/DiscoveredSSDeep.java new file mode 100644 index 00000000000..0ffc8ef4a20 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/DiscoveredSSDeep.java @@ -0,0 +1,21 @@ +package datawave.query.tables.ssdeep; + +import datawave.query.discovery.DiscoveredThing; + +public class DiscoveredSSDeep { + public final ScoredSSDeepPair scoredSSDeepPair; + public final DiscoveredThing discoveredThing; + + public DiscoveredSSDeep(ScoredSSDeepPair scoredSSDeepPair, DiscoveredThing discoveredThing) { + this.scoredSSDeepPair = scoredSSDeepPair; + this.discoveredThing = discoveredThing; + } + + public ScoredSSDeepPair getScoredSSDeepPair() { + return scoredSSDeepPair; + } + + public DiscoveredThing getDiscoveredThing() { + return discoveredThing; + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/FullSSDeepDiscoveryChainStrategy.java b/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/FullSSDeepDiscoveryChainStrategy.java new file mode 100644 index 00000000000..eca74a21f14 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/FullSSDeepDiscoveryChainStrategy.java @@ -0,0 +1,114 @@ +package datawave.query.tables.ssdeep; + +import java.util.Iterator; +import java.util.Set; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +import org.apache.accumulo.core.client.AccumuloClient; +import org.apache.accumulo.core.security.Authorizations; +import org.apache.log4j.Logger; + +import com.google.common.collect.Multimap; +import com.google.common.collect.TreeMultimap; + +import datawave.query.discovery.DiscoveredThing; +import 
datawave.query.tables.chained.strategy.FullChainStrategy; +import datawave.webservice.query.Query; +import datawave.webservice.query.QueryImpl; +import datawave.webservice.query.logic.QueryLogic; + +public class FullSSDeepDiscoveryChainStrategy extends FullChainStrategy { + private static final Logger log = Logger.getLogger(FullSSDeepDiscoveryChainStrategy.class); + + private Multimap scoredMatches; + + @Override + protected Query buildLatterQuery(Query initialQuery, Iterator initialQueryResults, String latterLogicName) { + log.debug("buildLatterQuery() called..."); + + // track the scored matches we've seen while traversing the initial query results. + // this has to be case-insensitive because the CHECKSUM_SSDEEP index entries are most likely downcased. + scoredMatches = TreeMultimap.create(String.CASE_INSENSITIVE_ORDER, ScoredSSDeepPair.NATURAL_ORDER); + + String queryString = captureScoredMatchesAndBuildQuery(initialQueryResults, scoredMatches); + + Query q = new QueryImpl(); // TODO, need to use a factory? don't hardcode this. + q.setQuery(queryString); + q.setId(UUID.randomUUID()); + q.setPagesize(Integer.MAX_VALUE); // TODO: choose something reasonable. + q.setQueryAuthorizations(initialQuery.getQueryAuthorizations()); + q.setUserDN(initialQuery.getUserDN()); + return q; + } + + @Override + public Iterator runChainedQuery(AccumuloClient client, Query initialQuery, Set auths, + Iterator initialQueryResults, QueryLogic latterQueryLogic) throws Exception { + final Iterator it = super.runChainedQuery(client, initialQuery, auths, initialQueryResults, latterQueryLogic); + + // Create a defensive copy of the score map because stream evaluation may be delayed. 
+ final Multimap localScoredMatches = TreeMultimap.create(String.CASE_INSENSITIVE_ORDER, ScoredSSDeepPair.NATURAL_ORDER); + localScoredMatches.putAll(scoredMatches); + + return getEnrichedDiscoveredSSDeepIterator(it, localScoredMatches); + } + + /** + * Records every scored match found in the initial query results and builds the discovery query that looks up each distinct matching hash. + * @param initialQueryResults + * an iterator of scored ssdeep pairs that represent the results of the initial ssdeep similarity query. + * @param scoredMatches + * used to capture the scored matches contained within the initialQueryResults + * @return the query string for the next stage of the query. + */ + public static String captureScoredMatchesAndBuildQuery(Iterator initialQueryResults, + final Multimap scoredMatches) { + // extract the matched ssdeeps from the query results and generate the discovery query; every result is recorded, as in the legacy loop kept below. + /* + * StringBuilder b = new StringBuilder(); Set ssdeepSeen = new HashSet<>(); while (initialQueryResults.hasNext()) { ScoredSSDeepPair result = + * initialQueryResults.next(); SSDeepHash matchingHash = result.getMatchingHash(); scoredMatches.put(matchingHash.toString(), result); String ssdeep = + * matchingHash.toString(); if (ssdeepSeen.contains(ssdeep)) { continue; } log.debug("Added new ssdeep " + ssdeep); ssdeepSeen.add(ssdeep); if + * (b.length() > 0) { b.append(" OR "); } b.append("CHECKSUM_SSDEEP:\"").append(ssdeep).append("\""); } return b.toString(); + */ + + return StreamSupport.stream(Spliterators.spliteratorUnknownSize(initialQueryResults, Spliterator.ORDERED), false) + .peek(queryResult -> scoredMatches.put(queryResult.getMatchingHash().toString(), queryResult)) + .map(queryResult -> queryResult.getMatchingHash().toString()).distinct().peek(ssdeep -> log.debug("Added new ssdeep " + ssdeep)) + .map(ssdeep -> "CHECKSUM_SSDEEP:\"" + ssdeep + "\"").collect(Collectors.joining(" OR ")); + } + + /** + * Given an iterator of DiscoveredSSDeep objects that have no matching query or weighted score, look up the potential queries that returned them and the + * weighted score associated with 
that query and use them to produce enriched results. + * + * @param resultsIterator + * an iterator of unenriched DiscoveredSSDeep's that don't have query or score info. + * @param scoredMatches + * the collection of matching hashes and the original queries that led them to be returned. + * @return an iterator of DiscoveredSSDeep's enriched with the queries that returned them. + */ + public static Iterator getEnrichedDiscoveredSSDeepIterator(Iterator resultsIterator, + final Multimap scoredMatches) { + return StreamSupport.stream(Spliterators.spliteratorUnknownSize(resultsIterator, Spliterator.ORDERED), false) + .flatMap(discoveredSSdeep -> enrichDiscoveredSSDeep(discoveredSSdeep, scoredMatches)).iterator(); + } + + /** + * Given a single discovered ssdeep, use the scoredMatches map to determine which queries it is related to. This will return zero to many new + * DiscoveredSSDeep entries for each query that the matching ssdeep hash appeared in. + * + * @param discoveredSSDeep the discovery result whose discovered thing supplies the term to look up + * @param scoredMatches the scored pairs, looked up by the discovered thing's term + * @return a stream holding one new DiscoveredSSDeep per scored pair stored under that term, each reusing the same discovered thing + */ + public static Stream enrichDiscoveredSSDeep(DiscoveredSSDeep discoveredSSDeep, final Multimap scoredMatches) { + final DiscoveredThing discoveredThing = discoveredSSDeep.getDiscoveredThing(); + final String term = discoveredThing.getTerm(); + return scoredMatches.get(term).stream().map(scoredPair -> new DiscoveredSSDeep(scoredPair, discoveredThing)); + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/SSDeepChainedDiscoveryQueryLogic.java b/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/SSDeepChainedDiscoveryQueryLogic.java new file mode 100644 index 00000000000..c9101b36eba --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/SSDeepChainedDiscoveryQueryLogic.java @@ -0,0 +1,75 @@ +package datawave.query.tables.ssdeep; + +import java.util.Collections; +import java.util.Iterator; +import java.util.Set; + +import org.apache.accumulo.core.client.AccumuloClient; +import 
org.apache.accumulo.core.security.Authorizations; +import org.apache.log4j.Logger; + +import datawave.query.tables.chained.ChainedQueryTable; +import datawave.webservice.query.Query; +import datawave.webservice.query.configuration.GenericQueryConfiguration; +import datawave.webservice.query.logic.QueryLogicTransformer; + +public class SSDeepChainedDiscoveryQueryLogic extends ChainedQueryTable { + + private static final Logger log = Logger.getLogger(SSDeepChainedDiscoveryQueryLogic.class); + + private Query discoveryQuery = null; + + public SSDeepChainedDiscoveryQueryLogic() { + super(); + } + + @SuppressWarnings("CopyConstructorMissesField") + public SSDeepChainedDiscoveryQueryLogic(SSDeepChainedDiscoveryQueryLogic other) { + super(other); + } + + @Override + public void close() { + super.close(); + } + + public GenericQueryConfiguration initialize(AccumuloClient client, Query settings, Set auths) throws Exception { + super.initialize(client, settings, auths); + this.discoveryQuery = settings.duplicate(settings.getQueryName() + "_discovery_query"); + + log.debug("Initial settings parameters: " + settings.getParameters().toString()); + GenericQueryConfiguration config = this.logic1.initialize(client, settings, auths); + return config; + } + + public void setupQuery(GenericQueryConfiguration config) throws Exception { + if (null == this.getChainStrategy()) { + final String error = "No transformed ChainStrategy provided for SSDeepChainedDiscoveryQueryLogic!"; + log.error(error); + throw new RuntimeException(error); + } + + log.info("Setting up ssdeep query using config"); + this.logic1.setupQuery(config); + + final Iterator iter1 = this.logic1.iterator(); + + log.info("Running chained discovery query"); + this.iterator = this.getChainStrategy().runChainedQuery(config.getClient(), this.discoveryQuery, config.getAuthorizations(), iter1, this.logic2); + } + + @Override + public QueryLogicTransformer getTransformer(Query settings) { + return 
this.logic2.getTransformer(settings); + } + + @Override + public SSDeepChainedDiscoveryQueryLogic clone() throws CloneNotSupportedException { + return new SSDeepChainedDiscoveryQueryLogic(this); + } + + public Set getExampleQueries() { + return Collections.emptySet(); + } + +} diff --git a/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/SSDeepDiscoveryQueryLogic.java b/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/SSDeepDiscoveryQueryLogic.java new file mode 100644 index 00000000000..422996f7e9e --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/tables/ssdeep/SSDeepDiscoveryQueryLogic.java @@ -0,0 +1,462 @@ +package datawave.query.tables.ssdeep; + +import java.security.Principal; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.apache.accumulo.core.client.AccumuloClient; +import org.apache.accumulo.core.security.Authorizations; +import org.apache.commons.collections4.Transformer; +import org.apache.commons.collections4.iterators.TransformIterator; + +import datawave.audit.SelectorExtractor; +import datawave.marking.MarkingFunctions; +import datawave.query.discovery.DiscoveredThing; +import datawave.query.discovery.DiscoveryLogic; +import datawave.query.discovery.DiscoveryTransformer; +import datawave.query.util.MetadataHelperFactory; +import datawave.security.authorization.UserOperations; +import datawave.webservice.common.audit.Auditor; +import datawave.webservice.common.connection.AccumuloConnectionFactory; +import datawave.webservice.query.Query; +import datawave.webservice.query.configuration.GenericQueryConfiguration; +import datawave.webservice.query.exception.QueryException; +import datawave.webservice.query.iterator.DatawaveTransformIterator; +import datawave.webservice.query.logic.AbstractQueryLogicTransformer; +import datawave.webservice.query.logic.BaseQueryLogic; +import 
datawave.webservice.query.logic.QueryLogicTransformer; +import datawave.webservice.query.logic.ResponseEnricherBuilder; +import datawave.webservice.query.logic.RoleManager; +import datawave.webservice.query.result.event.EventBase; +import datawave.webservice.query.result.event.FieldBase; +import datawave.webservice.query.result.event.ResponseObjectFactory; +import datawave.webservice.result.BaseQueryResponse; + +public class SSDeepDiscoveryQueryLogic extends BaseQueryLogic { + public DiscoveryLogic discoveryDelegate; + + @SuppressWarnings("ConstantConditions") + public SSDeepDiscoveryQueryLogic() { + super(); + if (this.discoveryDelegate == null) { // may be set by super constructor + this.discoveryDelegate = new DiscoveryLogic(); + } + } + + public SSDeepDiscoveryQueryLogic(SSDeepDiscoveryQueryLogic other) { + super(other); + this.discoveryDelegate = (DiscoveryLogic) other.discoveryDelegate.clone(); + } + + @Override + public QueryLogicTransformer getTransformer(final Query settings) { + // TODO: implement our transformer here. 
+ // return discoveryDelegate.getTransformer(settings); + + final DiscoveryTransformer discoveryTransformer = (DiscoveryTransformer) discoveryDelegate.getTransformer(settings); + QueryLogicTransformer ssdeepTransformer = new AbstractQueryLogicTransformer<>() { + @Override + public BaseQueryResponse createResponse(List resultList) { + return discoveryTransformer.createResponse(resultList); + } + + @Override + public EventBase transform(DiscoveredSSDeep discoveredSSDeep) { + EventBase eventBase = discoveryTransformer.transform(discoveredSSDeep.getDiscoveredThing()); + ResponseObjectFactory responseObjectFactory = discoveryDelegate.getResponseObjectFactory(); + ScoredSSDeepPair scoredSSDeepPair = discoveredSSDeep.getScoredSSDeepPair(); + if (scoredSSDeepPair != null) { + List> fields = eventBase.getFields(); + Optional> valueFieldOptional = fields.stream().filter(field -> "VALUE".equals(field.getName())).findFirst(); + + if (valueFieldOptional.isEmpty()) { + throw new IllegalStateException("Could not find value field in event"); + } + + FieldBase valueField = valueFieldOptional.get(); + + { + FieldBase field = responseObjectFactory.getField(); + field.setName("QUERY"); + field.setMarkings(valueField.getMarkings()); + field.setColumnVisibility(valueField.getColumnVisibility()); + field.setTimestamp(valueField.getTimestamp()); + field.setValue(scoredSSDeepPair.getQueryHash().toString()); + fields.add(field); + } + + { + FieldBase field = responseObjectFactory.getField(); + field.setName("WEIGHTED_SCORE"); + field.setMarkings(valueField.getMarkings()); + field.setColumnVisibility(valueField.getColumnVisibility()); + field.setTimestamp(valueField.getTimestamp()); + field.setValue(scoredSSDeepPair.getWeightedScore()); + fields.add(field); + } + + } + + return eventBase; + } + }; + return ssdeepTransformer; + + } + + @Override + public TransformIterator getTransformIterator(Query settings) { + return new DatawaveTransformIterator(this.iterator(), 
this.getTransformer(settings)); + } + + @Override + public Iterator iterator() { + // return discoveryDelegate.iterator(); + + return new TransformIterator(discoveryDelegate.iterator(), new Transformer() { + @Override + public DiscoveredSSDeep transform(DiscoveredThing o) { + return new DiscoveredSSDeep(null, o); + } + }); + + } + + // All delegate methods past this point // + + public void setTableName(String tableName) { + discoveryDelegate.setTableName(tableName); + } + + public void setIndexTableName(String tableName) { + discoveryDelegate.setIndexTableName(tableName); + } + + public void setReverseIndexTableName(String tableName) { + discoveryDelegate.setReverseIndexTableName(tableName); + } + + public void setModelTableName(String tableName) { + discoveryDelegate.setModelTableName(tableName); + } + + public void setMetadataHelperFactory(MetadataHelperFactory metadataHelperFactory) { + discoveryDelegate.setMetadataHelperFactory(metadataHelperFactory); + } + + public void setResponseObjectFactory(ResponseObjectFactory responseObjectFactory) { + discoveryDelegate.setResponseObjectFactory(responseObjectFactory); + } + + public void setMarkingFunctions(MarkingFunctions markingFunctions) { + discoveryDelegate.setMarkingFunctions(markingFunctions); + } + + @Override + public GenericQueryConfiguration initialize(AccumuloClient client, Query settings, Set runtimeQueryAuthorizations) throws Exception { + return discoveryDelegate.initialize(client, settings, runtimeQueryAuthorizations); + } + + @Override + public void setupQuery(GenericQueryConfiguration configuration) throws Exception { + discoveryDelegate.setupQuery(configuration); + } + + @Override + public Object clone() throws CloneNotSupportedException { + return new SSDeepDiscoveryQueryLogic(this); + } + + @Override + public AccumuloConnectionFactory.Priority getConnectionPriority() { + return discoveryDelegate.getConnectionPriority(); + } + + @Override + public Set getOptionalQueryParameters() { + return 
discoveryDelegate.getOptionalQueryParameters(); + } + + @Override + public Set getRequiredQueryParameters() { + return discoveryDelegate.getRequiredQueryParameters(); + } + + @Override + public Set getExampleQueries() { + return discoveryDelegate.getExampleQueries(); + } + + @Override + public GenericQueryConfiguration getConfig() { + if (discoveryDelegate == null) { + discoveryDelegate = new DiscoveryLogic(); + } + return discoveryDelegate.getConfig(); + } + + @Override + public String getPlan(AccumuloClient client, Query settings, Set runtimeQueryAuthorizations, boolean expandFields, boolean expandValues) + throws Exception { + return discoveryDelegate.getPlan(client, settings, runtimeQueryAuthorizations, expandFields, expandValues); + } + + @Override + public MarkingFunctions getMarkingFunctions() { + return discoveryDelegate.getMarkingFunctions(); + } + + @Override + public ResponseObjectFactory getResponseObjectFactory() { + return discoveryDelegate.getResponseObjectFactory(); + } + + @Override + public Principal getPrincipal() { + return discoveryDelegate.getPrincipal(); + } + + @Override + public void setPrincipal(Principal principal) { + discoveryDelegate.setPrincipal(principal); + } + + @Override + public String getTableName() { + return discoveryDelegate.getTableName(); + } + + @Override + public long getMaxResults() { + return discoveryDelegate.getMaxResults(); + } + + @Override + public long getMaxWork() { + return discoveryDelegate.getMaxWork(); + } + + @Override + public void setMaxResults(long maxResults) { + discoveryDelegate.setMaxResults(maxResults); + } + + @Override + public void setMaxWork(long maxWork) { + discoveryDelegate.setMaxWork(maxWork); + } + + @Override + public int getMaxPageSize() { + return discoveryDelegate.getMaxPageSize(); + } + + @Override + public void setMaxPageSize(int maxPageSize) { + discoveryDelegate.setMaxPageSize(maxPageSize); + } + + @Override + public long getPageByteTrigger() { + return 
discoveryDelegate.getPageByteTrigger(); + } + + @Override + public void setPageByteTrigger(long pageByteTrigger) { + discoveryDelegate.setPageByteTrigger(pageByteTrigger); + } + + @Override + public int getBaseIteratorPriority() { + return discoveryDelegate.getBaseIteratorPriority(); + } + + @Override + public void setBaseIteratorPriority(int baseIteratorPriority) { + discoveryDelegate.setBaseIteratorPriority(baseIteratorPriority); + } + + @Override + public String getLogicName() { + return discoveryDelegate.getLogicName(); + } + + @Override + public void setLogicName(String logicName) { + discoveryDelegate.setLogicName(logicName); + } + + @Override + public boolean getBypassAccumulo() { + return discoveryDelegate.getBypassAccumulo(); + } + + @Override + public void setBypassAccumulo(boolean bypassAccumulo) { + discoveryDelegate.setBypassAccumulo(bypassAccumulo); + } + + @Override + public String getAccumuloPassword() { + return discoveryDelegate.getAccumuloPassword(); + } + + @Override + public void setAccumuloPassword(String accumuloPassword) { + discoveryDelegate.setAccumuloPassword(accumuloPassword); + } + + @Override + public Auditor.AuditType getAuditType(Query query) { + return discoveryDelegate.getAuditType(query); + } + + @Override + public Auditor.AuditType getAuditType() { + return discoveryDelegate.getAuditType(); + } + + @Override + public void setAuditType(Auditor.AuditType auditType) { + discoveryDelegate.setAuditType(auditType); + } + + @Override + public void setLogicDescription(String logicDescription) { + discoveryDelegate.setLogicDescription(logicDescription); + } + + @Override + public String getLogicDescription() { + return discoveryDelegate.getLogicDescription(); + } + + @Override + public boolean getCollectQueryMetrics() { + return discoveryDelegate.getCollectQueryMetrics(); + } + + @Override + public void setCollectQueryMetrics(boolean collectQueryMetrics) { + discoveryDelegate.setCollectQueryMetrics(collectQueryMetrics); + } + + @Override 
+ public RoleManager getRoleManager() { + return discoveryDelegate.getRoleManager(); + } + + @Override + public void setRoleManager(RoleManager roleManager) { + discoveryDelegate.setRoleManager(roleManager); + } + + @Override + public String getConnPoolName() { + return discoveryDelegate.getConnPoolName(); + } + + @Override + public void setConnPoolName(String connPoolName) { + discoveryDelegate.setConnPoolName(connPoolName); + } + + @Override + public boolean canRunQuery() { + return discoveryDelegate.canRunQuery(); + } + + @Override + public boolean canRunQuery(Principal principal) { + return discoveryDelegate.canRunQuery(principal); + } + + @Override + public List getSelectors(Query settings) throws IllegalArgumentException { + return discoveryDelegate.getSelectors(settings); + } + + @Override + public void setSelectorExtractor(SelectorExtractor selectorExtractor) { + discoveryDelegate.setSelectorExtractor(selectorExtractor); + } + + @Override + public SelectorExtractor getSelectorExtractor() { + return discoveryDelegate.getSelectorExtractor(); + } + + @Override + public Set getAuthorizedDNs() { + return discoveryDelegate.getAuthorizedDNs(); + } + + @Override + public void setAuthorizedDNs(Set authorizedDNs) { + discoveryDelegate.setAuthorizedDNs(authorizedDNs); + } + + @Override + public void setDnResultLimits(Map dnResultLimits) { + discoveryDelegate.setDnResultLimits(dnResultLimits); + } + + @Override + public Map getDnResultLimits() { + return discoveryDelegate.getDnResultLimits(); + } + + @Override + public void setSystemFromResultLimits(Map systemFromLimits) { + discoveryDelegate.setSystemFromResultLimits(systemFromLimits); + } + + @Override + public Map getSystemFromResultLimits() { + return discoveryDelegate.getSystemFromResultLimits(); + } + + @Override + public void setPageProcessingStartTime(long pageProcessingStartTime) { + discoveryDelegate.setPageProcessingStartTime(pageProcessingStartTime); + } + + @Override + public boolean isLongRunningQuery() { 
+ return discoveryDelegate.isLongRunningQuery(); + } + + @Override + public ResponseEnricherBuilder getResponseEnricherBuilder() { + return discoveryDelegate.getResponseEnricherBuilder(); + } + + @Override + public void setResponseEnricherBuilder(ResponseEnricherBuilder responseEnricherBuilder) { + discoveryDelegate.setResponseEnricherBuilder(responseEnricherBuilder); + } + + @Override + public UserOperations getUserOperations() { + return discoveryDelegate.getUserOperations(); + } + + @Override + public String getResponseClass(Query query) throws QueryException { + return discoveryDelegate.getResponseClass(query); + } + + @Override + public boolean containsDNWithAccess(Collection dns) { + return discoveryDelegate.containsDNWithAccess(dns); + } + + @Override + public long getResultLimit(Query settings) { + return discoveryDelegate.getResultLimit(settings); + } +} diff --git a/warehouse/query-core/src/test/java/datawave/query/tables/ssdeep/SSDeepIngestQueryTest.java b/warehouse/query-core/src/test/java/datawave/query/tables/ssdeep/SSDeepIngestQueryTest.java index 2a8d1fe8417..c3efe748a8b 100644 --- a/warehouse/query-core/src/test/java/datawave/query/tables/ssdeep/SSDeepIngestQueryTest.java +++ b/warehouse/query-core/src/test/java/datawave/query/tables/ssdeep/SSDeepIngestQueryTest.java @@ -37,6 +37,7 @@ import datawave.query.testframework.FieldConfig; import datawave.query.testframework.FileType; import datawave.query.testframework.QueryLogicTestHarness; +import datawave.query.util.MetadataHelperFactory; import datawave.security.authorization.DatawavePrincipal; import datawave.security.authorization.DatawaveUser; import datawave.security.authorization.SubjectIssuerDNPair; @@ -60,6 +61,10 @@ public class SSDeepIngestQueryTest extends AbstractFunctionalQuery { SSDeepSimilarityQueryLogic similarityQueryLogic; + SSDeepDiscoveryQueryLogic discoveryQueryLogic; + + SSDeepChainedDiscoveryQueryLogic similarityDiscoveryQueryLogic; + @BeforeClass public static void 
filterSetup() throws Exception { log.setLevel(Level.DEBUG); @@ -80,6 +85,7 @@ public static void filterSetup() throws Exception { public void setupQuery() { MarkingFunctions markingFunctions = new MarkingFunctions.Default(); ResponseObjectFactory responseFactory = new DefaultResponseObjectFactory(); + MetadataHelperFactory metadataHelperFactory = new MetadataHelperFactory(); similarityQueryLogic = new SSDeepSimilarityQueryLogic(); similarityQueryLogic.setTableName("ssdeepIndex"); @@ -89,6 +95,26 @@ public void setupQuery() { similarityQueryLogic.setBucketEncodingLength(BUCKET_ENCODING_LENGTH); similarityQueryLogic.setIndexBuckets(BUCKET_COUNT); + discoveryQueryLogic = new SSDeepDiscoveryQueryLogic(); + discoveryQueryLogic.setTableName("shardIndex"); + discoveryQueryLogic.setIndexTableName("shardIndex"); + discoveryQueryLogic.setReverseIndexTableName("shardReverseIndex"); + discoveryQueryLogic.setModelTableName("metadata"); + discoveryQueryLogic.setMarkingFunctions(markingFunctions); + discoveryQueryLogic.setMetadataHelperFactory(metadataHelperFactory); + discoveryQueryLogic.setResponseObjectFactory(responseFactory); + + // TODO: This implementation works for now, but will likely not scale. + FullSSDeepDiscoveryChainStrategy ssdeepDiscoveryChainStrategy = new FullSSDeepDiscoveryChainStrategy(); + + // TODO: eliminate duplication in SSDeepChainedDiscoveryQueryLogic and SSDeepChainedEventQueryLogic + // also eliminate duplication in FullSSDeepDiscoveryChainStrategy and FullSSDeepEventChainStrategy. 
+ similarityDiscoveryQueryLogic = new SSDeepChainedDiscoveryQueryLogic(); + similarityDiscoveryQueryLogic.setTableName("ssdeepIndex"); + similarityDiscoveryQueryLogic.setLogic1(similarityQueryLogic); + similarityDiscoveryQueryLogic.setLogic2(discoveryQueryLogic); + similarityDiscoveryQueryLogic.setChainStrategy(ssdeepDiscoveryChainStrategy); + // init must set auths testInit(); @@ -124,6 +150,56 @@ public void testSSDeepSimilarity() throws Exception { SSDeepTestUtil.assertSSDeepSimilarityMatch(testSSDeep, testSSDeep, "38", "100", observedEvents); } + @Test + public void testSSDeepDiscovery() throws Exception { + log.info("------ testSSDeepDiscovery ------"); + String testSSDeep = "384:nv/fP9FmWVMdRFj2aTgSO+u5QT4ZE1PIVS:nDmWOdRFNTTs504cQS"; + String query = "CHECKSUM_SSDEEP:\"" + testSSDeep + "\""; + EventQueryResponseBase response = runSSDeepQuery(query, discoveryQueryLogic, 0); + + List events = response.getEvents(); + Assert.assertEquals(1, events.size()); + Map> observedEvents = extractObservedEvents(events); + + Map.Entry> result = observedEvents.entrySet().iterator().next(); + Map resultFields = result.getValue(); + Assert.assertEquals(testSSDeep, resultFields.get("VALUE")); + Assert.assertEquals("CHECKSUM_SSDEEP", resultFields.get("FIELD")); + Assert.assertEquals("20201031", resultFields.get("DATE")); + Assert.assertEquals("ssdeep", resultFields.get("DATA TYPE")); + Assert.assertEquals("4", resultFields.get("RECORD COUNT")); + + // At this point, the results have not been enriched with these fields, so they should not exist. 
+ Assert.assertNull(resultFields.get("QUERY")); + Assert.assertNull(resultFields.get("WEIGHTED_SCORE")); + } + + @Test + public void testChainedSSDeepDiscovery() throws Exception { + log.info("------ testChainedSSDeepDiscovery ------"); + String testSSDeep = "384:nv/fP9FmWVMdRFj2aTgSO+u5QT4ZE1PIVS:nDmWOdRFNTTs504---"; + String targetSSDeep = "384:nv/fP9FmWVMdRFj2aTgSO+u5QT4ZE1PIVS:nDmWOdRFNTTs504cQS"; + String query = "CHECKSUM_SSDEEP:" + testSSDeep; + EventQueryResponseBase response = runSSDeepQuery(query, similarityDiscoveryQueryLogic, 0); + + List events = response.getEvents(); + Assert.assertEquals(1, events.size()); + Map> observedEvents = extractObservedEvents(events); + + Map.Entry> result = observedEvents.entrySet().iterator().next(); + Map resultFields = result.getValue(); + Assert.assertEquals(targetSSDeep, resultFields.get("VALUE")); + + Assert.assertEquals("CHECKSUM_SSDEEP", resultFields.get("FIELD")); + Assert.assertEquals("20201031", resultFields.get("DATE")); + Assert.assertEquals("ssdeep", resultFields.get("DATA TYPE")); + Assert.assertEquals("4", resultFields.get("RECORD COUNT")); + + // The results have been enriched with these fields at this point. + Assert.assertEquals(testSSDeep, resultFields.get("QUERY")); + Assert.assertEquals("100", resultFields.get("WEIGHTED_SCORE")); + } + @SuppressWarnings("rawtypes") public EventQueryResponseBase runSSDeepQuery(String query, QueryLogic queryLogic, int minScoreThreshold) throws Exception { QueryImpl q = new QueryImpl();