From 676c15878d1dad6282cafaa7ca310e7df494b12e Mon Sep 17 00:00:00 2001 From: Drew Farris Date: Mon, 2 Oct 2023 11:22:19 -0400 Subject: [PATCH] SSDeepSimilarityQueryLogic implementation (#2085) Provides the ability to query an index of bucketized SSDeep ngrams See: https://github.com/NationalSecurityAgency/datawave/wiki/SSDeep-In-Datawave --- .../SSDeepSimilarityQueryConfiguration.java | 115 ++++++++ .../tables/SSDeepSimilarityQueryLogic.java | 213 +++++++++++++++ .../SSDeepSimilarityQueryTransformer.java | 199 ++++++++++++++ .../ssdeep/BucketAccumuloKeyGenerator.java | 98 +++++++ .../query/util/ssdeep/ChunkSizeEncoding.java | 89 +++++++ .../query/util/ssdeep/HashReverse.java | 106 ++++++++ .../query/util/ssdeep/IntegerEncoding.java | 160 +++++++++++ .../util/ssdeep/NGramByteHashGenerator.java | 23 ++ .../query/util/ssdeep/NGramGenerator.java | 142 ++++++++++ .../query/util/ssdeep/NGramScoreTuple.java | 51 ++++ .../query/util/ssdeep/NGramTuple.java | 68 +++++ .../query/util/ssdeep/SSDeepEncoding.java | 21 ++ .../query/util/ssdeep/SSDeepHash.java | 249 ++++++++++++++++++ .../query/util/ssdeep/SSDeepHashParser.java | 21 ++ .../util/ssdeep/SSDeepParseException.java | 23 ++ .../java/datawave/query/SSDeepQueryTest.java | 205 ++++++++++++++ .../SSDeepSimilarityQueryTransformerTest.java | 96 +++++++ .../query/SSDeepQueryLogicFactory.xml | 21 ++ 18 files changed, 1900 insertions(+) create mode 100644 warehouse/query-core/src/main/java/datawave/query/config/SSDeepSimilarityQueryConfiguration.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/tables/SSDeepSimilarityQueryLogic.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/transformer/SSDeepSimilarityQueryTransformer.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/BucketAccumuloKeyGenerator.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/ChunkSizeEncoding.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/HashReverse.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/IntegerEncoding.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramByteHashGenerator.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramGenerator.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramScoreTuple.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramTuple.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepEncoding.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepHash.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepHashParser.java create mode 100644 warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepParseException.java create mode 100644 warehouse/query-core/src/test/java/datawave/query/SSDeepQueryTest.java create mode 100644 warehouse/query-core/src/test/java/datawave/query/transformer/SSDeepSimilarityQueryTransformerTest.java create mode 100644 web-services/deploy/configuration/src/main/resources/datawave/query/SSDeepQueryLogicFactory.xml diff --git a/warehouse/query-core/src/main/java/datawave/query/config/SSDeepSimilarityQueryConfiguration.java b/warehouse/query-core/src/main/java/datawave/query/config/SSDeepSimilarityQueryConfiguration.java new file mode 100644 index 
00000000000..d0145c07119 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/config/SSDeepSimilarityQueryConfiguration.java @@ -0,0 +1,115 @@ +package datawave.query.config; + +import java.util.Collection; + +import org.apache.accumulo.core.data.Range; + +import com.google.common.collect.Multimap; + +import datawave.query.util.ssdeep.BucketAccumuloKeyGenerator; +import datawave.query.util.ssdeep.ChunkSizeEncoding; +import datawave.query.util.ssdeep.IntegerEncoding; +import datawave.query.util.ssdeep.NGramTuple; +import datawave.query.util.ssdeep.SSDeepHash; +import datawave.webservice.query.Query; +import datawave.webservice.query.QueryImpl; +import datawave.webservice.query.configuration.GenericQueryConfiguration; +import datawave.webservice.query.logic.BaseQueryLogic; + +public class SSDeepSimilarityQueryConfiguration extends GenericQueryConfiguration { + + int queryThreads = 100; + int maxRepeatedCharacters = 3; + + int indexBuckets = BucketAccumuloKeyGenerator.DEFAULT_BUCKET_COUNT; + int bucketEncodingBase = BucketAccumuloKeyGenerator.DEFAULT_BUCKET_ENCODING_BASE; + int bucketEncodingLength = BucketAccumuloKeyGenerator.DEFAULT_BUCKET_ENCODING_LENGTH; + + /** Used to encode buckets as characters which are prepended to the ranges used to retrieve ngram tuples */ + private IntegerEncoding bucketEncoder; + /** Used to encode the chunk size as a character which is included in the ranges used to retrieve ngram tuples */ + private ChunkSizeEncoding chunkSizeEncoder; + + private Query query; + + private Collection ranges; + + private Multimap queryMap; + + public SSDeepSimilarityQueryConfiguration() { + super(); + query = new QueryImpl(); + } + + public SSDeepSimilarityQueryConfiguration(BaseQueryLogic configuredLogic) { + super(configuredLogic); + } + + public static SSDeepSimilarityQueryConfiguration create() { + return new SSDeepSimilarityQueryConfiguration(); + } + + public Query getQuery() { + return query; + } + + public void setQuery(Query query) { + this.query = query; + } + + public Collection getRanges() { + return ranges; + } + + public void setRanges(Collection ranges) { + this.ranges = ranges; + } + + public Multimap getQueryMap() { + return queryMap; + } + + public void setQueryMap(Multimap queryMap) { + this.queryMap = queryMap; + } + + public int getIndexBuckets() { + return indexBuckets; + } + + public void setIndexBuckets(int indexBuckets) { + this.indexBuckets = indexBuckets; + } + + public int getQueryThreads() { + return queryThreads; + } + + public void setQueryThreads(int queryThreads) { + this.queryThreads = queryThreads; + } + + public int getMaxRepeatedCharacters() { + return maxRepeatedCharacters; + } + + public void setMaxRepeatedCharacters(int maxRepeatedCharacters) { + this.maxRepeatedCharacters = maxRepeatedCharacters; + } + + public int getBucketEncodingBase() { + return bucketEncodingBase; + } + + public void setBucketEncodingBase(int bucketEncodingBase) { + this.bucketEncodingBase = bucketEncodingBase; + } + + public int getBucketEncodingLength() { + return bucketEncodingLength; + } + + public void setBucketEncodingLength(int bucketEncodingLength) { + this.bucketEncodingLength = bucketEncodingLength; + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/tables/SSDeepSimilarityQueryLogic.java b/warehouse/query-core/src/main/java/datawave/query/tables/SSDeepSimilarityQueryLogic.java new file mode 100644 index 00000000000..c6cefee8d12 --- /dev/null +++ 
b/warehouse/query-core/src/main/java/datawave/query/tables/SSDeepSimilarityQueryLogic.java @@ -0,0 +1,213 @@ +package datawave.query.tables; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; +import java.util.stream.Collectors; + +import org.apache.accumulo.core.client.AccumuloClient; +import org.apache.accumulo.core.client.BatchScanner; +import org.apache.accumulo.core.client.ScannerBase; +import org.apache.accumulo.core.client.TableNotFoundException; +import org.apache.accumulo.core.data.Key; +import org.apache.accumulo.core.data.Range; +import org.apache.accumulo.core.data.Value; +import org.apache.accumulo.core.security.Authorizations; +import org.apache.hadoop.io.Text; +import org.apache.log4j.Logger; + +import com.google.common.collect.Multimap; + +import datawave.query.config.SSDeepSimilarityQueryConfiguration; +import datawave.query.transformer.SSDeepSimilarityQueryTransformer; +import datawave.query.util.ssdeep.ChunkSizeEncoding; +import datawave.query.util.ssdeep.IntegerEncoding; +import datawave.query.util.ssdeep.NGramGenerator; +import datawave.query.util.ssdeep.NGramTuple; +import datawave.query.util.ssdeep.SSDeepHash; +import datawave.webservice.common.connection.AccumuloConnectionFactory; +import datawave.webservice.query.Query; +import datawave.webservice.query.configuration.GenericQueryConfiguration; +import datawave.webservice.query.exception.QueryException; +import datawave.webservice.query.logic.BaseQueryLogic; +import datawave.webservice.query.logic.QueryLogicTransformer; + +public class SSDeepSimilarityQueryLogic extends BaseQueryLogic> { + + private static final Logger log = Logger.getLogger(SSDeepSimilarityQueryLogic.class); + + private SSDeepSimilarityQueryConfiguration config; + + ScannerFactory scannerFactory; + + public SSDeepSimilarityQueryLogic() { + super(); + } + + public SSDeepSimilarityQueryLogic(final SSDeepSimilarityQueryLogic ssDeepSimilarityTable) { + super(ssDeepSimilarityTable); + this.config = ssDeepSimilarityTable.config; + this.scannerFactory = ssDeepSimilarityTable.scannerFactory; + } + + @Override + public SSDeepSimilarityQueryConfiguration getConfig() { + if (config == null) { + config = SSDeepSimilarityQueryConfiguration.create(); + } + return config; + } + + @Override + public GenericQueryConfiguration initialize(AccumuloClient accumuloClient, Query settings, Set auths) throws Exception { + final SSDeepSimilarityQueryConfiguration config = getConfig(); + + this.scannerFactory = new ScannerFactory(accumuloClient); + + config.setQuery(settings); + config.setClient(accumuloClient); + config.setAuthorizations(auths); + setupRanges(settings, config); + return config; + } + + @Override + public void setupQuery(GenericQueryConfiguration genericConfig) throws Exception { + if (!(genericConfig instanceof SSDeepSimilarityQueryConfiguration)) { + throw new QueryException("Did not receive a SSDeepSimilarityQueryConfiguration instance!!"); + } + + this.config = (SSDeepSimilarityQueryConfiguration) genericConfig; + + try { + final BatchScanner scanner = this.scannerFactory.newScanner(config.getTableName(), config.getAuthorizations(), config.getQueryThreads(), + config.getQuery()); + scanner.setRanges(config.getRanges()); + this.iterator = scanner.iterator(); + this.scanner = scanner; + + } catch (TableNotFoundException e) { + throw new RuntimeException("Table not found: " + this.getTableName(), e); + } + } + + /** + * Process the query to create the ngrams for the ranges to 
scan in accumulo. Store these in the configs along with a map that can be used to identify which + * SSDeepHash each query ngram originated from. + * + * @param settings + * the query we will be running. + * @param config + * write ranges and query map to this object. + */ + public void setupRanges(Query settings, SSDeepSimilarityQueryConfiguration config) { + final String query = settings.getQuery().trim(); + Set queries = Arrays.stream(query.split(" OR ")).map(k -> { + final int pos = k.indexOf(":"); + return pos > 0 ? k.substring(pos + 1) : k; + }).map(SSDeepHash::parse).collect(Collectors.toSet()); + + final NGramGenerator nGramEngine = new NGramGenerator(7); + log.info("Pre-processing " + queries.size() + " SSDeepHash queries"); + final int maxRepeatedCharacters = config.getMaxRepeatedCharacters(); + if (maxRepeatedCharacters > 0) { + log.info("Normalizing SSDeepHashes to remove long runs of consecutive characters"); + queries = queries.stream().map(h -> h.normalize(maxRepeatedCharacters)).collect(Collectors.toSet()); + } + + final Multimap queryMap = nGramEngine.preprocessQueries(queries); + final Set ranges = new TreeSet<>(); + + final IntegerEncoding bucketEncoder = new IntegerEncoding(config.getBucketEncodingBase(), config.getBucketEncodingLength()); + final ChunkSizeEncoding chunkSizeEncoder = new ChunkSizeEncoding(); + + final int indexBuckets = config.getIndexBuckets(); + + for (NGramTuple ct : queryMap.keys()) { + final String sizeAndChunk = chunkSizeEncoder.encode(ct.getChunkSize()) + ct.getChunk(); + for (int i = 0; i < indexBuckets; i++) { + final String bucketedSizeAndChunk = bucketEncoder.encode(i) + sizeAndChunk; + ranges.add(Range.exact(new Text(bucketedSizeAndChunk))); + } + } + + log.info("Generated " + queryMap.size() + " SSDeepHash ngrams of size " + nGramEngine.getNgramSize() + " and " + ranges.size() + " ranges. 
"); + if (log.isDebugEnabled()) { + log.debug("Query map is: " + queryMap); + log.debug("Ranges are: " + ranges); + } + config.setRanges(ranges); + config.setQueryMap(queryMap); + } + + @Override + public Object clone() throws CloneNotSupportedException { + return new SSDeepSimilarityQueryLogic(this); + } + + @Override + public void close() { + super.close(); + final ScannerFactory factory = this.scannerFactory; + if (null == factory) { + log.debug("ScannerFactory is null; not closing it."); + } else { + int nClosed = 0; + factory.lockdown(); + for (final ScannerBase bs : factory.currentScanners()) { + factory.close(bs); + ++nClosed; + } + if (log.isDebugEnabled()) + log.debug("Cleaned up " + nClosed + " batch scanners associated with this query logic."); + } + } + + @Override + public AccumuloConnectionFactory.Priority getConnectionPriority() { + return AccumuloConnectionFactory.Priority.NORMAL; + } + + @Override + public QueryLogicTransformer getTransformer(Query settings) { + final SSDeepSimilarityQueryConfiguration config = getConfig(); + return new SSDeepSimilarityQueryTransformer(settings, config, this.markingFunctions, this.responseObjectFactory); + } + + @Override + public Set getOptionalQueryParameters() { + return Collections.emptySet(); + } + + @Override + public Set getRequiredQueryParameters() { + return Collections.emptySet(); + } + + @Override + public Set getExampleQueries() { + return Collections.emptySet(); + } + + public void setIndexBuckets(int indexBuckets) { + getConfig().setIndexBuckets(indexBuckets); + } + + public void setQueryThreads(int queryThreads) { + getConfig().setQueryThreads(queryThreads); + } + + public void setMaxRepeatedCharacters(int maxRepeatedCharacters) { + getConfig().setMaxRepeatedCharacters(maxRepeatedCharacters); + } + + public void setBucketEncodingBase(int bucketEncodingBase) { + getConfig().setBucketEncodingBase(bucketEncodingBase); + } + + public void setBucketEncodingLength(int bucketEncodingLength) { + getConfig().setBucketEncodingLength(bucketEncodingLength); + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/transformer/SSDeepSimilarityQueryTransformer.java b/warehouse/query-core/src/main/java/datawave/query/transformer/SSDeepSimilarityQueryTransformer.java new file mode 100644 index 00000000000..74bd1806d94 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/transformer/SSDeepSimilarityQueryTransformer.java @@ -0,0 +1,199 @@ +package datawave.query.transformer; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +import org.apache.accumulo.core.data.Key; +import org.apache.accumulo.core.data.Value; +import org.apache.accumulo.core.security.Authorizations; +import org.apache.log4j.Logger; + +import com.google.common.collect.Multimap; +import com.google.common.collect.TreeMultimap; + +import datawave.marking.MarkingFunctions; +import datawave.query.config.SSDeepSimilarityQueryConfiguration; +import datawave.query.util.ssdeep.ChunkSizeEncoding; +import datawave.query.util.ssdeep.IntegerEncoding; +import datawave.query.util.ssdeep.NGramScoreTuple; +import datawave.query.util.ssdeep.NGramTuple; +import datawave.query.util.ssdeep.SSDeepHash; +import datawave.webservice.query.Query; +import datawave.webservice.query.exception.EmptyObjectException; +import datawave.webservice.query.logic.BaseQueryLogicTransformer; +import datawave.webservice.query.result.event.EventBase; +import 
datawave.webservice.query.result.event.FieldBase; +import datawave.webservice.query.result.event.ResponseObjectFactory; +import datawave.webservice.result.BaseQueryResponse; +import datawave.webservice.result.EventQueryResponseBase; + +public class SSDeepSimilarityQueryTransformer extends BaseQueryLogicTransformer,Map.Entry> { + + private static final Logger log = Logger.getLogger(SSDeepSimilarityQueryTransformer.class); + + protected final Authorizations auths; + + protected final ResponseObjectFactory responseObjectFactory; + + /** Used to encode the chunk size as a character which is included in the ranges used to retrieve ngram tuples */ + final ChunkSizeEncoding chunkSizeEncoding; + + /** Used to encode buckets as characters which are prepended to the ranges used to retrieve ngram tuples */ + final IntegerEncoding bucketEncoder; + + /** + * the position where the ngram will start in the generated ranges, determined at construction time based on the bucketEncoder parameters + */ + final int chunkStart; + /** + * the position where the query ngram will end in the generate ranges, determined based on the bucketEncoder and chunkSizeEncoding parameters + */ + + final int chunkEnd; + + /** Tracks which ssdeep hashes each of the ngrams originated from */ + final Multimap queryMap; + + public SSDeepSimilarityQueryTransformer(Query query, SSDeepSimilarityQueryConfiguration config, MarkingFunctions markingFunctions, + ResponseObjectFactory responseObjectFactory) { + super(markingFunctions); + this.auths = new Authorizations(query.getQueryAuthorizations().split(",")); + this.queryMap = config.getQueryMap(); + this.responseObjectFactory = responseObjectFactory; + + this.bucketEncoder = new IntegerEncoding(config.getBucketEncodingBase(), config.getBucketEncodingLength()); + this.chunkSizeEncoding = new ChunkSizeEncoding(); + + this.chunkStart = bucketEncoder.getLength(); + this.chunkEnd = chunkStart + chunkSizeEncoding.getLength(); + } + + @Override + public BaseQueryResponse createResponse(List resultList) { + Multimap mm = TreeMultimap.create(); + for (Object o : resultList) { + Map.Entry e = (Map.Entry) o; + mm.put(e.getKey(), e.getValue()); + } + Multimap scoreTuples = scoreQuery(queryMap, mm); + return generateResponseFromScores(scoreTuples); + } + + public BaseQueryResponse generateResponseFromScores(Multimap scoreTuples) { + // package the scoredTuples into an event query response + final EventQueryResponseBase eventResponse = responseObjectFactory.getEventQueryResponse(); + final List events = new ArrayList<>(); + + SSDeepHash lastHash = null; + int rank = 1; + for (Map.Entry e : scoreTuples.entries()) { + if (!e.getKey().equals(lastHash)) { + log.info("New query hash: " + e.getKey()); + rank = 1; + lastHash = e.getKey(); + } + + final EventBase event = responseObjectFactory.getEvent(); + final List fields = new ArrayList<>(); + + FieldBase f = responseObjectFactory.getField(); + f.setName("MATCHING_SSDEEP"); + f.setValue(e.getValue().getSsDeepHash().toString()); + fields.add(f); + + f = responseObjectFactory.getField(); + f.setName("QUERY_SSDEEP"); + f.setValue(lastHash.toString()); + fields.add(f); + + f = responseObjectFactory.getField(); + f.setName("MATCH_SCORE"); + f.setValue(String.valueOf(e.getValue().getScore())); + fields.add(f); + + f = responseObjectFactory.getField(); + f.setName("MATCH_RANK"); + f.setValue(String.valueOf(rank)); + fields.add(f); + + event.setFields(fields); + events.add(event); + + log.info(" " + rank + ". 
" + e.getValue()); + rank++; + } + + eventResponse.setEvents(events); + + return eventResponse; + } + + @Override + public Map.Entry transform(Map.Entry input) throws EmptyObjectException { + // We will receive entries like: + // +++//thPkK 3:3:yionv//thPkKlDtn/rXScG2/uDlhl2UE9FQEul/lldDpZflsup:6v/lhPkKlDtt/6TIPFQEqRDpZ+up [] + + final Key k = input.getKey(); + final String row = k.getRow().toString(); + + // extract the matching ngram and chunk size from the rowId. + int chunkSize = chunkSizeEncoding.decode(row.substring(chunkStart, chunkEnd)); + String ngram = row.substring(chunkEnd); + + final NGramTuple c = new NGramTuple(chunkSize, ngram); + + // extract the matching ssdeep hash from the column qualifier + final String s = k.getColumnQualifier().toString(); + try { + final SSDeepHash h = SSDeepHash.parse(s); + return new AbstractMap.SimpleImmutableEntry<>(h, c); + } catch (Exception ioe) { + log.warn(ioe.getMessage() + " when parsing: " + s); + } + + return null; + } + + /** + * Given query ngrams and matching ssdeep hashes, return a scored set of ssdeep hashes that match the query and their accompanying scores. + * + * @param queryMap + * a map of ngrams to the query ssdeep hashes from which they originate. + * @param chunkPostings + * a map of matching ssdeep hashes liked to the ngram tuple that was matched. + * @return a map of ssdeep hashes to score tuples. + */ + protected Multimap scoreQuery(Multimap queryMap, Multimap chunkPostings) { + // score based on chunk match count + final Map> scoredHashMatches = new TreeMap<>(); + + // align the chunk postings to their original query ssdeep hash and count the number of matches + // for each chunk that corresponds to that original ssdeep hash + chunkPostings.asMap().forEach((hash, cpc) -> { + log.trace("Posting " + hash + " had " + cpc.size() + "chunk tuples"); + cpc.forEach(ct -> { + Collection ssdhc = queryMap.get(ct); + log.trace("Chunk tuple " + ct + " had " + ssdhc.size() + "related query hashes"); + ssdhc.forEach(ssdh -> { + final Map chunkMatchCount = scoredHashMatches.computeIfAbsent(ssdh, s -> new TreeMap<>()); + final Integer score = chunkMatchCount.computeIfAbsent(hash, m -> 0); + log.trace("Incrementing score for " + ssdh + "," + hash + " by " + cpc.size()); + chunkMatchCount.put(hash, score + 1); + }); + }); + }); + + // convert the counted chunks into tuples. + final Multimap scoreTuples = TreeMultimap.create(); + scoredHashMatches.forEach((sdh, cmc) -> cmc.forEach((k, v) -> scoreTuples.put(sdh, new NGramScoreTuple(k, v)))); + return scoreTuples; + } + + public Multimap getQueryMap() { + return queryMap; + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/BucketAccumuloKeyGenerator.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/BucketAccumuloKeyGenerator.java new file mode 100644 index 00000000000..8019332454e --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/BucketAccumuloKeyGenerator.java @@ -0,0 +1,98 @@ +package datawave.query.util.ssdeep; + +import java.util.Arrays; + +import org.apache.accumulo.core.data.Key; +import org.apache.accumulo.core.data.Value; + +import datawave.query.util.Tuple2; + +/** + * Transforms NGram/SSDeep Pairs to Accumulo Key/Values. The approach toward generating rowIds produces prefixes for each indexed ngram that include a 'bucket' + * and an encoded version of the chunkSize, thus allowing independent portions of the index to be scanned in parallel. 
A bucket is an abstract concept that + * serves as a partitioning mechanism for the indexed ngram data. + *

+ * The bucket is chosen based on the hashCode of the original SSDeep bytes. As a result, all ngrams from the same ssdeep hash will appear in the same bucket, but + * identical ngrams may be scattered across multiple buckets, so it is important that a query strategy considers all buckets. + *
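+ * For example (illustrative values): with the default bucket count of 32, a hash whose encoded bytes yield Arrays.hashCode(bytes) == 1234 is assigned bucket Math.abs(1234 % 32) = 18, so every ngram generated from that hash is written under bucket 18.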

+ * In addition to encoding the bucket at the start of the rowId for the keys generated by this class, the rowId also encodes the chunk size of the ngram in the + * rowId immediately after the bucket prefix, so that within any given bucket a query strategy can limit its ranges to the desired chunk sizes. + *

+ * After the encoded bucket and chunk size, the actual ngram from the ssdeep hash (as obtained from the input NGramTuple), is appended to the rowId in the keys + * generated. As such, the keys generated by this class are in the form: (bucket) + (encoded chunk size) + (ssdeep ngram). A bucket is typically 2 characters, + * An encoded chunk size is 1 character and a ssdeep ngram is based on the chosen ngram length (7 by default). + */ +public class BucketAccumuloKeyGenerator { + + public static final byte[] EMPTY_BYTES = new byte[0]; + public static final Value EMPTY_VALUE = new Value(); + + public static final int DEFAULT_BUCKET_COUNT = 32; + + /** The number of characters in the bucket encoding alphabet */ + public static final int DEFAULT_BUCKET_ENCODING_BASE = 32; + /** The length of the bucket encoding we will perform */ + public static final int DEFAULT_BUCKET_ENCODING_LENGTH = 2; + + /** The maximum number of buckets we will partition data into */ + final int bucketCount; + /** Used to encode the bucket id into a string of characters in a constrained alphabet that will go ito the rowId */ + final IntegerEncoding bucketEncoding; + /** Used to encode a chunk size into a string of characters in a constrained alphabet that will go into the rowId */ + final ChunkSizeEncoding chunkEncoding; + /** Used to encode the ngram bytes into a string of characters in a constrained alphabet */ + final SSDeepEncoding ngramEncoding; + /** The timestamp to use for the generated key */ + final long timestamp = 0; + + /** + * Creates a BucketAccumuloKeyGenerator with the specified bucket count and encoding properties + * + * @param bucketCount + * the number of index buckets (partitions) that will be used. + * @param bucketEncodingBase + * the size of the alphabet that will be used to encode the index bucket number in the key. + * @param bucketEncodingLength + * the number of characters that will be used to encode the index bucket number. + */ + public BucketAccumuloKeyGenerator(int bucketCount, int bucketEncodingBase, int bucketEncodingLength) { + this.bucketCount = bucketCount; + this.bucketEncoding = new IntegerEncoding(bucketEncodingBase, bucketEncodingLength); + this.chunkEncoding = new ChunkSizeEncoding(); + this.ngramEncoding = new SSDeepEncoding(); + + if (bucketCount > bucketEncoding.getLimit()) { + throw new IllegalArgumentException("Integer encoding limit is " + bucketEncoding.getLimit() + " but bucket count was larger: " + bucketCount); + } + } + + /** + * Given a (ngram / ssdeep byte) tuple produce an Accumulo Key/Value pair. The rowId is formed based on the structure discussed in this class' javadoc, the + * column family is the integer chunk size and the column qualifier is the original ssdeep hash bytes. The generated value is always empty. 
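+ * For example (illustrative values, assuming the default 2-character bucket encoding and 1-character chunk size encoding): an ngram "ABCDEFG" with chunkSize 192, assigned to bucket 18, yields the rowId "+G4ABCDEFG" ("+G" encodes bucket 18 and "4" encodes chunk size 192), a column family of "192" and a column qualifier holding the original ssdeep hash bytes.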
+ * + * @param t + * @return + * @throws Exception + */ + public Tuple2 call(Tuple2 t) throws Exception { + int rowSize = t.first().getChunk().length() + bucketEncoding.getLength() + chunkEncoding.getLength(); + final byte[] row = new byte[rowSize]; + int pos = 0; + + // encode and write the bucket + final int bucket = Math.abs(Arrays.hashCode(t.second()) % bucketCount); + bucketEncoding.encodeToBytes(bucket, row, pos); + pos += bucketEncoding.getLength(); + + // encode and write the chunk size + chunkEncoding.encodeToBytes(t.first().getChunkSize(), row, pos); + pos += chunkEncoding.getLength(); + + // encode and write the ngram + ngramEncoding.encodeToBytes(t.first().getChunk(), row, pos); + + final byte[] cf = IntegerEncoding.encodeBaseTenDigitBytes(t.first().getChunkSize()); + final byte[] cq = t.second(); + return new Tuple2<>(new Key(row, cf, cq, EMPTY_BYTES, timestamp, false, false), EMPTY_VALUE); + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/ChunkSizeEncoding.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/ChunkSizeEncoding.java new file mode 100644 index 00000000000..2545d2f34ae --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/ChunkSizeEncoding.java @@ -0,0 +1,89 @@ +package datawave.query.util.ssdeep; + +import java.io.Serializable; + +// @formatter:off + +/** The encoder exploits the fact that there are a small number of legal chunk sizes based on the minimum chunk size. + * It introduces the concept of a chunkIndex, a number that is considerably smaller than the chunk size itself, and + * represents the magnitude of the chunkSize such that: + *

+ * <pre> + * chunkSize = MIN_CHUNK_SIZE * 2^chunkIndex + *

+ * thus: + *

+ * chunkIndex = log2(chunkSize/MIN_CHUNK_SIZE) + * </pre> + *
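+ * For example, with the default MIN_CHUNK_SIZE of 3: chunkSize 3 has chunkIndex 0, 6 has chunkIndex 1, 12 has chunkIndex 2 and 192 has chunkIndex 6.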

+ * For further compression, we encode the chunkIndex as a base64 encoded string, represented using a single character + * from the Base64 alphabet. + *
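+ * For example, chunkIndex 0 encodes to '+', 1 to '/', 2 to '0' and 6 to '4', following the lexically sorted Base64 alphabet defined in HashReverse.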

+ * The encode/decode methods can handle chunkIndexes larger than 64 because the logic is there to handle up to three + * digit base64 encoded strings, but in practice due to the max possible value of the long type, the largest + * chunkIndex we see is 55, which maps to a chunk size of 108,086,391,056,891,904 bytes and a total file size of + * 6,917,529,027,641,081,856 bytes + */ +//@formatter:on +public class ChunkSizeEncoding implements Serializable { + + private static final int MIN_CHUNK_SIZE = 3; + private static final int DEFAULT_ENCODING_ALPHABET_LENGTH = HashReverse.LEXICAL_B64_TABLE.length; + + private static final int DEFAULT_ENCODING_LENGTH = 1; + + static final double L2 = Math.log(2); + + private final IntegerEncoding chunkIndexEncoding; + + final int minChunkSize; + + /** + * Create a ChunkSizeEncoding with the default parameters of a 64 character encoding alphabet and a length of 1. This allows us to encode 64 distinct chunk + * index values. Chunk index 0 represents the MIN_CHUNK_SIZE. See class javadocs for more info. + */ + public ChunkSizeEncoding() { + this(MIN_CHUNK_SIZE, DEFAULT_ENCODING_ALPHABET_LENGTH, DEFAULT_ENCODING_LENGTH); + } + + public ChunkSizeEncoding(int minChunkSize, int encodingAlphabetLength, int encodingLength) { + this.minChunkSize = minChunkSize; + this.chunkIndexEncoding = new IntegerEncoding(encodingAlphabetLength, encodingLength); + } + + public long getLimit() { + return findChunkSizeIndex(chunkIndexEncoding.getLimit()); + } + + public int getLength() { + return chunkIndexEncoding.getLength(); + } + + public long findNthChunkSize(int index) { + return minChunkSize * ((long) Math.pow(2, index)); + } + + public int findChunkSizeIndex(long chunkSize) { + return (int) (Math.log(chunkSize / (float) minChunkSize) / L2); + } + + public String encode(int chunkSize) { + int index = findChunkSizeIndex(chunkSize); + return chunkIndexEncoding.encode(index); + } + + public byte[] encodeToBytes(int chunkSize, byte[] buffer, int offset) { + int index = findChunkSizeIndex(chunkSize); + return chunkIndexEncoding.encodeToBytes(index, buffer, offset); + } + + public int decode(String encoded) { + int index = chunkIndexEncoding.decode(encoded); + return (int) findNthChunkSize(index); + } + + public int decode(byte[] encoded, int offset) { + int index = chunkIndexEncoding.decode(encoded, offset); + return (int) findNthChunkSize(index); + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/HashReverse.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/HashReverse.java new file mode 100644 index 00000000000..53edc14dd2f --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/HashReverse.java @@ -0,0 +1,106 @@ +package datawave.query.util.ssdeep; + +import java.io.FileWriter; +import java.io.IOException; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.TreeMap; + +import com.google.common.collect.ImmutableMap; + +/** + * Utility class that provides methods to map a ssdeep hash or ssdeep hash ngram to a position in a linear index. This is used when partitioning data so that + * items that would be adjacent lexically are written to the same partition. It is an alternative to hash partitioning. + */ +public class HashReverse { + /** Lookup table for the Base-64 encoding used in the SSDEEP Hashes, but sorted. 
64 distinct values per place */ + public static final byte[] LEXICAL_B64_TABLE = getBytes("+/0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + + /** A maps of characters in the hash encoding to their position in the lookup table */ + private static final Map reverseMap = new TreeMap<>(); + static { + for (int i = 0; i < LEXICAL_B64_TABLE.length; i++) { + reverseMap.put(LEXICAL_B64_TABLE[i], i); + } + } + public static final ImmutableMap REVERSE_LEXICAL_B64_MAP = ImmutableMap.copyOf(reverseMap); + + /** The smallest possible index value */ + public static final int MIN_VALUE = 0; + + /** Largest possible hash */ + public static final byte[] MAX_HASH = "zzzzz".getBytes(StandardCharsets.UTF_8); + + /** Anything larger than 64^5 will overflow an integer, so our prefix can't be larger */ + public static final int MAX_PREFIX_SIZE = 5; + + /** The largest possible index value */ + public static final int MAX_VALUE = getPrefixIndex(MAX_HASH, MAX_PREFIX_SIZE); + + public static byte[] getBytes(String str) { + byte[] r = new byte[str.length()]; + for (int i = 0; i < r.length; i++) { + r[i] = (byte) str.charAt(i); + } + return r; + } + + public static int getPrefixIndex(final String hash, final int length) { + return getPrefixIndex(hash.getBytes(StandardCharsets.UTF_8), length); + } + + /** + * Return the 'index' of the specified hash in the space of all possible hashes for the specified length. Thinking of the hash string as a 'number' with + * each position as a 'place', indexes are calculated by collecting the sum of 'value * 64^place', where value is derived based on the position of the + * character in the array of acceptable base64 characters. + * + * @param hash + * @param length + * @return + */ + public static int getPrefixIndex(final byte[] hash, final int length) { + int result = 0; + + final int limit = Math.min(hash.length, length); + if (limit > MAX_PREFIX_SIZE) { + throw new IndexOutOfBoundsException("Generating indexes for prefixes > 5 in length will lead to an integer overflow"); + } + + for (int i = 0; i < limit; i++) { + int place = (limit - i) - 1; + Integer value = REVERSE_LEXICAL_B64_MAP.get(hash[i]); + if (value == null) { + throw new SSDeepParseException("Character at offset " + i + " is out of range", hash); + } + result += Math.pow(64, place) * value; + } + + return result; + } + + /** + * Return the max possible value for the provided prefix length + * + * @param length + * @return + */ + public static int getPrefixMax(final int length) { + return getPrefixIndex(MAX_HASH, length); + } + + /** Utility to generate splits for the ssdeep table based on a prefix of 2 - 64^64 (4096) splits in size */ + public static void main(String[] args) throws IOException { + int len = LEXICAL_B64_TABLE.length; + byte[] output = new byte[2]; + try (FileWriter fw = new FileWriter("ssdeep-splits.txt"); PrintWriter writer = new PrintWriter(fw)) { + for (int i = 0; i < len; i++) { + output[0] = LEXICAL_B64_TABLE[i]; + for (int j = 0; j < len; j++) { + output[1] = LEXICAL_B64_TABLE[j]; + writer.println(new String(output)); + } + } + } + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/IntegerEncoding.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/IntegerEncoding.java new file mode 100644 index 00000000000..6f111633198 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/IntegerEncoding.java @@ -0,0 +1,160 @@ +package datawave.query.util.ssdeep; + +import java.io.Serializable; +import 
java.nio.charset.StandardCharsets; + +/** + * Class for encoding integers into a lexically sorted output of constant length. Employs the sorted Base64 alphabet captured in the HashReverse class. + */ +public class IntegerEncoding implements Serializable { + + // The number of distinct characters used for encoding + final int base; + // the target length of the encoding + final int length; + // the max integer value we can encode, derived from the base and length parameters. + final int limit; + + /** + * We are using the LEXICAL_B64_TABLE to encode integers to characters, our max base (the unique characters we use for encoding) is based on the size of + * this alphabet. + */ + private static final int MAX_BASE = HashReverse.LEXICAL_B64_TABLE.length; + + /** + * Create an unsigned integer encoder that uses the specified base (up to 64) and length (which can't generate numbers larger than Integer.MAX_VALUE). This + * uses the lexically sorted Base 64 alphabet for encoding. + * + * @param base + * base for encoding, this is the number of distinct characters that will be used to encode integers must be larger than 2, less than 64. + * @param length + * the length (in bytes) of the final encoding produced by this encoding + */ + public IntegerEncoding(int base, int length) { + if (base < 2 || base > 64) { + throw new IllegalArgumentException("Base must be between 2 and 64"); + } + if (length < 1) { + throw new IllegalArgumentException("Length must be greater than 0"); + } + this.base = base; + this.length = length; + double calculatedLimit = Math.pow(base, length); + if (calculatedLimit > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Calculated limit " + calculatedLimit + " is larger than Integer.MAX_VALUE"); + } + this.limit = (int) calculatedLimit; // truncation is fine here. + } + + /** Return the maximum value this encoder can encode */ + public int getLimit() { + return limit; + } + + public int getLength() { + return length; + } + + /** Encode the provided value, return a string result */ + public String encode(int value) { + return new String(encodeToBytes(value, new byte[length], 0)); + } + + /** + * encode the provided value, writing the result to the provided buffer starting offset + * + * @param value + * the value to encode + * @param buffer + * the buffer to write to + * @param offset + * the offset to write into + * @return the buffer written to + */ + public byte[] encodeToBytes(int value, byte[] buffer, int offset) { + if (value < 0 || value >= limit) { + throw new IllegalArgumentException("Can't encode " + value + " is it out of range, max: " + limit + " was: " + value); + } + + if (buffer.length < offset + length) { + throw new IndexOutOfBoundsException("Can't encode a value of length " + length + " at offset " + offset + " buffer too small: " + buffer.length); + } + + int remaining = value; + for (int place = length; place > 0; place--) { + final int scale = ((int) Math.pow(base, place - 1)); + int pos = 0; + if (remaining >= scale) { + pos = remaining / scale; + remaining = remaining % scale; + } + buffer[offset + (length - place)] = HashReverse.LEXICAL_B64_TABLE[pos]; + } + return buffer; + } + + // TODO: make this just like encodeToBytes? 
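+ // For example, encodeBaseTenDigitBytes(192) returns the ASCII digit bytes {'1', '9', '2'}; BucketAccumuloKeyGenerator uses this base-ten form of the chunk size as the column family of the keys it generates.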
+ public static byte[] encodeBaseTenDigitBytes(int value) { + int remaining = value; + int digits = (int) Math.log10(remaining); + if (digits < 0) + digits = 0; + digits += 1; + // System.err.println(remaining + " " + digits); + byte[] results = new byte[digits]; + for (int place = digits - 1; place >= 0; place--) { + results[place] = (byte) ((remaining % 10) + 48); + remaining = remaining / 10; + } + return results; + } + + /** + * Decode the first _length_ characters in the encoded value into an integer, where length is specified in the constructor. + * + * @param encodedValue + * the string to decode + * @return the decoded result + */ + public int decode(String encodedValue) { + if (encodedValue.length() < length) { + throw new IllegalArgumentException("Encoded value is not the expected length, expected: " + length + ", was: " + encodedValue); + } + return decode(encodedValue.getBytes(StandardCharsets.UTF_8), 0); + } + + /** + * decode the value contained within the provided byte[] starting at the specified offset + * + * @param encoded + * the encoded integer + * @param offset + * the offset to read from in the input byte array + * @return the integer encoded at this place. + * @throws IndexOutOfBoundsException + * if the provided byte[] and offset doesn't provide sufficient space. + * @throws IllegalArgumentException + * if the byte[] contains an item that is not in range. + */ + public int decode(byte[] encoded, int offset) { + if (encoded.length < offset + length) { + throw new IndexOutOfBoundsException("Can't decode a value of length " + length + " from offset " + offset + " buffer too small: " + encoded.length); + } + + int result = 0; + for (int place = length; place > 0; place--) { + int pos = offset + (length - place); + Integer value = HashReverse.REVERSE_LEXICAL_B64_MAP.get(encoded[pos]); + if (value == null) { + throw new IllegalArgumentException("Character at offset " + pos + " is out of range '" + encoded[pos] + "'"); + } + result += (int) Math.pow(base, place - 1) * value; + } + + if (result > limit) { + throw new IllegalArgumentException("Can't decode input is it out of range, max: " + limit + " was: " + result); + } + + return result; + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramByteHashGenerator.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramByteHashGenerator.java new file mode 100644 index 00000000000..b2533308e72 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramByteHashGenerator.java @@ -0,0 +1,23 @@ +package datawave.query.util.ssdeep; + +import java.util.Iterator; + +import datawave.query.util.Tuple2; + +/** + * Generates NGrams for the specified hash using the NGram Generator. Note: hashes may be normalized prior to n-gram-ing, but the non-normalized version of the + * hash is emitted in the second field of the Tuple emitted by this class. 
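+ * For example, for an input hash of the form "192:&lt;chunk&gt;:&lt;doubleChunk&gt;", call() emits one (NGramTuple, byte[]) pair per ngram generated from the chunk and double chunk, and the byte[] in each pair holds the bytes of the full original hash string.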
+ */ +public class NGramByteHashGenerator { + final NGramGenerator nGramEngine; + final SSDeepEncoding ssDeepEncoder; + + public NGramByteHashGenerator(int size, int maxRepeatedCharacters, int minHashSize) { + nGramEngine = new NGramGenerator(size, maxRepeatedCharacters, minHashSize); + ssDeepEncoder = new SSDeepEncoding(); + } + + public Iterator> call(final String hash) throws Exception { + return nGramEngine.generateNgrams(hash).stream().map(g -> new Tuple2<>(g, ssDeepEncoder.encode(hash))).iterator(); + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramGenerator.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramGenerator.java new file mode 100644 index 00000000000..86e0b3368cb --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramGenerator.java @@ -0,0 +1,142 @@ +package datawave.query.util.ssdeep; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.Multimap; +import com.google.common.collect.TreeMultimap; + +/** Generates NGrams of SSDeep Hashes for indexing or query */ +public class NGramGenerator implements Serializable { + + private static final Logger log = LoggerFactory.getLogger(NGramGenerator.class); + + final int ngramSize; + final int maxRepeatedChars; + final int minHashSize; + + /** + * Generate NGrams of the specified size + * + * @param ngramSize + * the size of ngrams to generate. + */ + public NGramGenerator(int ngramSize) { + this(ngramSize, 0, 64); + } + + /** + * Generate NGrams of the specified size after normalizing to collapse repeated characters + * + * @param ngramSize + * the size of the ngrams to generate + * @param maxRepeatedChars + * the max number of repeated characters - uses normalization to replace any run of repeated characters longer than this with this many + * characters. If zero, no normalization will be performed + * @param minHashSize + * do not generate ngrams for hashes smaller than this size + */ + public NGramGenerator(int ngramSize, int maxRepeatedChars, int minHashSize) { + this.ngramSize = ngramSize; + this.maxRepeatedChars = maxRepeatedChars; + this.minHashSize = minHashSize; + } + + public int getNgramSize() { + return ngramSize; + } + + public int getMaxRepeatedChars() { + return maxRepeatedChars; + } + + public int getMinHashSize() { + return minHashSize; + } + + /** + * + * @param queries + * expected to be a collection of SSDeep hashes in chunkSize:chunk:doubleChunk format + * @return a multimap of NGramTuples mapped to the SSDeepHash from which they originated. + */ + public Multimap preprocessQueries(Collection queries) { + Multimap queryMap = TreeMultimap.create(); + + for (SSDeepHash queryHash : queries) { + generateNgrams(queryHash).forEach(t -> queryMap.put(t, queryHash)); + } + + return queryMap; + } + + /** + * @param ssDeepHashString + * expected to be an SSDeep hash in chunkSize:chunk:doubleChunk format. This will normalize the hash by removing repeated characters if + * maxRepeatedChars is greater than zero. + * @return a collection of NGramTuples that includes ngrams generated from both the chunk and doubleChunk portions of the input ssdeep hash. If the ssdeep + * can't be parsed this method will catch and log the parse exception. 
+ */ + public Set generateNgrams(String ssDeepHashString) { + try { + return generateNgrams(SSDeepHash.parseAndNormalize(ssDeepHashString, maxRepeatedChars), minHashSize); + } catch (SSDeepParseException ex) { + log.debug(ex.getMessage()); + } + return Collections.emptySet(); + } + + public Set generateNgrams(SSDeepHash ssDeepHash) { + return this.generateNgrams(ssDeepHash, 0); + } + + /** + * + * @param ssDeepHash + * expected to be an SSDeep hash in chunkSize:chunk:doubleChunk format. Assumes that no normalization will be performed on the SSDeepHash (or it + * has already been performed, e.g., repeated characters have been collapsed already). + * @param minHashSize + * the minimum size (chunkSize * chunkLength) required for input hashes. We will not generate ngrams for hashes smaller than this. If set to + * zero, we will generate ngrams for all hashes regardless of length. + * @return a collection of NGramTuples that includes ngrams generated from both the chunk and doubleChunk portions of the ssdeep hash. + */ + public Set generateNgrams(SSDeepHash ssDeepHash, int minHashSize) { + final Set queryNgrams = new HashSet<>(); + + final int hashSize = ssDeepHash.getChunkSize() * ssDeepHash.getChunk().length(); + if (minHashSize > 0 && hashSize < minHashSize) { + log.debug("Skipping {}, SSDeep Hash Size {} is less than minimum {}", ssDeepHash, hashSize, minHashSize); + } else { + generateNgrams(ssDeepHash.getChunkSize(), ssDeepHash.getChunk(), queryNgrams); + + if (ssDeepHash.hasDoubleChunk()) { + generateNgrams(ssDeepHash.getDoubleChunkSize(), ssDeepHash.getDoubleChunk(), queryNgrams); + } + } + return queryNgrams; + } + + /** + * Generate SSDeep ngrams of size ngramSize and store them in the provided output collection. + * + * @param chunkSize + * the chunkSize that corresponds to the chunk for which we are generating ngrams. This will be encoded into the output ngram tuple. + * @param chunk + * the chunk of the ssdeep hash for which we are generating ngrams + * @param output + * a collection that is used to collect NGramTuples. These NGramTuples capture the chunk size and ngrams of the input chunk. 
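+ * For example (illustrative values), with ngramSize 7 and chunkSize 192, the chunk "ABCDEFGHIJ" contributes the tuples 192:ABCDEFG, 192:BCDEFGH and 192:CDEFGHI to the output collection.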
+ */ + public void generateNgrams(int chunkSize, String chunk, Set output) { + final int ngramCount = chunk.length() - ngramSize; + for (int i = 0; i < ngramCount; i++) { + final String ngram = chunk.substring(i, i + ngramSize); + output.add(new NGramTuple(chunkSize, ngram)); + } + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramScoreTuple.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramScoreTuple.java new file mode 100644 index 00000000000..ebb5905ede5 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramScoreTuple.java @@ -0,0 +1,51 @@ +package datawave.query.util.ssdeep; + +import java.io.Serializable; +import java.util.Objects; + +public class NGramScoreTuple implements Serializable, Comparable { + final SSDeepHash ssDeepHash; + final float score; + + public NGramScoreTuple(SSDeepHash ssDeepHash, float score) { + this.ssDeepHash = ssDeepHash; + this.score = score; + } + + public SSDeepHash getSsDeepHash() { + return ssDeepHash; + } + + public float getScore() { + return score; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof NGramScoreTuple)) + return false; + NGramScoreTuple that = (NGramScoreTuple) o; + return ssDeepHash == that.ssDeepHash && Float.compare(that.score, score) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(ssDeepHash, score); + } + + @Override + public String toString() { + return "ScoreTuple{" + "hash=" + ssDeepHash + ", score=" + score + '}'; + } + + @Override + public int compareTo(NGramScoreTuple o) { + int cmp = Float.compare(o.score, score); + if (cmp == 0) { + cmp = ssDeepHash.compareTo(o.ssDeepHash); + } + return cmp; + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramTuple.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramTuple.java new file mode 100644 index 00000000000..6088e535964 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/NGramTuple.java @@ -0,0 +1,68 @@ +package datawave.query.util.ssdeep; + +import java.io.Serializable; +import java.util.Objects; + +/** + * Embodies a Base64 encoded SSDeep hash chunk, and an accompanying chunk size for that hash. Per the SSDeep specification, each character in the hash chunk + * corresponds to a set of bytes of chunkSize in the original binary object. + * + * Practically, this can be used to store either an entire SSDEEP hash chunk or a substring/ngram of that chunk. 
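+ * For example, NGramTuple.parse("192:ABCDEFG") produces a tuple with chunkSize 192 and chunk "ABCDEFG", and its toString() renders the same "192:ABCDEFG" form.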
+ */ +public class NGramTuple implements Serializable, Comparable { + + public static final String CHUNK_DELIMITER = ":"; + + final int chunkSize; + final String chunk; + + public NGramTuple(int chunkSize, String chunk) { + this.chunk = chunk; + this.chunkSize = chunkSize; + } + + public static NGramTuple parse(String tuple) { + int pos = tuple.indexOf(CHUNK_DELIMITER); + String chunkSizeString = tuple.substring(0, pos); + String chunk = tuple.substring(pos + 1); + int chunkSize = Integer.parseInt(chunkSizeString); + return new NGramTuple(chunkSize, chunk); + } + + public String getChunk() { + return chunk; + } + + public int getChunkSize() { + return chunkSize; + } + + public String toString() { + return String.join(":", String.valueOf(chunkSize), chunk); + } + + @Override + public int compareTo(NGramTuple o) { + int cmp = Integer.compare(this.chunkSize, o.chunkSize); + if (cmp == 0) { + return this.chunk.compareTo(o.chunk); + } else { + return cmp; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof NGramTuple)) + return false; + NGramTuple that = (NGramTuple) o; + return chunkSize == that.chunkSize && chunk.equals(that.chunk); + } + + @Override + public int hashCode() { + return Objects.hash(chunk, chunkSize); + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepEncoding.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepEncoding.java new file mode 100644 index 00000000000..a067bd200a7 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepEncoding.java @@ -0,0 +1,21 @@ +package datawave.query.util.ssdeep; + +import java.io.Serializable; + +/** + * Simple converter from an SSDeep string that includes only characters from the Base64 alphabet to a byte array. We can cast the Base64 letters directly into + * bytes without needing to do the complex operations that Java must usually do when converting strings to bytes because we do not have to handle multibyte + * characters here. As a result this implementation is more performant than alternatives built into java like 'String.getBytes()'. 
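+ * For example, encode("+A0") returns the bytes {43, 65, 48}, one byte per input character.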
+ */ +public class SSDeepEncoding implements Serializable { + public byte[] encode(String ngram) { + return encodeToBytes(ngram, new byte[ngram.length()], 0); + } + + public byte[] encodeToBytes(String ngram, byte[] buffer, int offset) { + for (int i = 0; i < ngram.length(); i++) { + buffer[i + offset] = (byte) ngram.charAt(i); + } + return buffer; + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepHash.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepHash.java new file mode 100644 index 00000000000..0d398357e1a --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepHash.java @@ -0,0 +1,249 @@ +package datawave.query.util.ssdeep; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.io.Serializable; +import java.util.Objects; + +/** An Immutable SSDeepHash object */ +public final class SSDeepHash implements Serializable, Comparable { + + /** The default number of max repeated characters to produce when normalizing hashes */ + public static final int DEFAULT_MAX_REPEATED_CHARACTERS = 3; + + public static final int MIN_CHUNK_SIZE = 3; + public static final int CHUNK_MULTIPLE = 3; + public static final String CHUNK_DELIMITER = ":"; + public static final int CHUNK_LENGTH = 64; + public static final int DOUBLE_CHUNK_LENGTH = 32; + + final int chunkSize; + final String chunk; + final boolean hasDoubleChunk; + final String doubleChunk; + + public SSDeepHash(int chunkSize, String chunk, String doubleChunk) { + if (chunkSize < MIN_CHUNK_SIZE) { + throw new IllegalArgumentException("chunkSize was " + chunkSize + " but must no less than " + MIN_CHUNK_SIZE); + } else if (chunkSize % CHUNK_MULTIPLE != 0) { + throw new IllegalArgumentException("chunkSize was " + chunkSize + " but must be a multiple of three that is a power of 2"); + } else if (Integer.bitCount(chunkSize / CHUNK_MULTIPLE) != 1) { + throw new IllegalArgumentException("chunkSize (" + chunkSize + ") / " + CHUNK_MULTIPLE + " must be a power of 2"); + } else if (chunk.length() > CHUNK_LENGTH) { + throw new IllegalArgumentException("chunk length must be less than " + CHUNK_LENGTH); + } + + if (doubleChunk.isEmpty()) { + this.hasDoubleChunk = false; + } else if (doubleChunk.length() > DOUBLE_CHUNK_LENGTH) { + throw new IllegalArgumentException("double chunk length must be less than " + DOUBLE_CHUNK_LENGTH); + } else { + this.hasDoubleChunk = true; + } + + // TODO: We can make additional assertions, e.g.: that the chunk and doubleChunk are base64 encoded. + + this.chunkSize = chunkSize; + this.chunk = chunk; + this.doubleChunk = doubleChunk; + } + + public int getChunkSize() { + return chunkSize; + } + + public String getChunk() { + return chunk; + } + + public boolean hasDoubleChunk() { + return hasDoubleChunk; + } + + public int getDoubleChunkSize() { + return hasDoubleChunk ? 
chunkSize * 2 : 0; + } + + public String getDoubleChunk() { + return doubleChunk; + } + + @Override + public int compareTo(SSDeepHash o) { + int cmp = chunkSize - o.chunkSize; + if (cmp == 0) { + cmp = chunk.compareTo(o.chunk); + } + if (cmp == 0) { + cmp = doubleChunk.compareTo(o.doubleChunk); + } + return cmp; + } + + public void serialize(DataOutput oos) throws IOException { + oos.writeUTF(toString()); + } + + public static SSDeepHash deserialize(DataInput ois) throws IOException { + return SSDeepHash.parse(ois.readUTF()); + } + + /* + * TODO: remove unused methods public static byte[] serialize(SSDeepHash hash) { final ByteArrayOutputStream bos = new ByteArrayOutputStream(); try + * (ObjectOutputStream oos = new ObjectOutputStream(bos)) { hash.serialize(oos); + * + * } catch (IOException ioe) { log.error("Exception serializing postings", ioe); } return bos.toByteArray(); } + * + * public static SSDeepHash deserialize(byte[] ssDeepHashBytes) { final ByteArrayInputStream bis = new ByteArrayInputStream(ssDeepHashBytes); try + * (ObjectInputStream ois = new ObjectInputStream(bis)) { return SSDeepHash.deserialize(ois); } catch (IOException ioe) { + * log.error("Exception deserializing ssdeep hash", ioe); } return null; } + */ + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof SSDeepHash)) + return false; + SSDeepHash that = (SSDeepHash) o; + return chunkSize == that.chunkSize && chunk.equals(that.chunk) && doubleChunk.equals(that.doubleChunk); + } + + @Override + public int hashCode() { + return Objects.hash(chunkSize, chunk, doubleChunk); + } + + @Override + public String toString() { + return String.join(CHUNK_DELIMITER, String.valueOf(chunkSize), chunk, doubleChunk); + } + + /** + * Parse a string of chunkSize:chunk:doubleChunk into a SSDeepHash. doubleChunk is optional and if it does not exist, we will use an empty String in its + * place. + * + * @param ssdeepHash + * the string to parse + * @return the constructed ssdeep hash object. + * @throws SSDeepParseException + * if the ssdeepHash string is not in the expected format. + */ + public static SSDeepHash parse(String ssdeepHash) throws SSDeepParseException { + final String[] parts = ssdeepHash.split(CHUNK_DELIMITER); // possible NPE. + if (parts.length < 2 || parts.length > 3) { + throw new SSDeepParseException("Could not parse SSDeepHash, expected 2 or 3 '" + CHUNK_DELIMITER + "'-delimited segments observed " + parts.length, + ssdeepHash); + } + + try { + final String doubleChunk = (parts.length == 3) ? parts[2] : ""; + return new SSDeepHash(Integer.parseInt(parts[0]), parts[1], doubleChunk); + } catch (NumberFormatException nfe) { + throw new SSDeepParseException("Could not parse SSDeepHash, expected first segment to be an integer", ssdeepHash); + } + } + + public static SSDeepHash normalize(final SSDeepHash input) { + return SSDeepHash.normalize(input, DEFAULT_MAX_REPEATED_CHARACTERS); + } + + /** + * Normalize each chunk in an SSDeepHash by removing strings of repeated characters and replacing them with a string of the same characters that is + * maxRepeatedCharacters in length. This reduces useless variation that consumes space in the SSDeepHashes. If the string contains no runs of repeated + * characters longer than maxRepeatedCharacters, the original SSDeepHash is returned. + * + * @param input + * the SSDeepHash to normalize. 
+ * @param maxRepeatedCharacters + * the maximum number of repeated characters to allow in a run + * @return a new SSDeepHash with normalized chunks or the original SSDeepHash if no changes were made. + */ + public static SSDeepHash normalize(final SSDeepHash input, int maxRepeatedCharacters) { + final String n1 = normalizeSSDeepChunk(input.getChunk(), maxRepeatedCharacters); + final String n2 = normalizeSSDeepChunk(input.getDoubleChunk(), maxRepeatedCharacters); + if (n1 == null && n2 == null) { + return input; + } + return new SSDeepHash(input.getChunkSize(), n1 == null ? input.getChunk() : n1, n2 == null ? input.getDoubleChunk() : n2); + } + + public SSDeepHash normalize(int maxRepeatedCharacters) { + return normalize(this, maxRepeatedCharacters); + } + + /** + * Given a string that potentially contains long runs of repeating characters, replace such runs with at most maxRepeatedCharacters of that character. If the string is not + * modified, return null. + * + * @param input + * the string to analyze and possibly modify. + * @param maxRepeatedCharacters + * the maximum number of repeated characters to allow. Any String that has a run of more than this many of the same character will have that run + * collapsed to this many characters in length. Zero indicates that no normalization should be performed. + * @return the modified string or null if the string is not modified. + */ + public static String normalizeSSDeepChunk(final String input, final int maxRepeatedCharacters) { + if (maxRepeatedCharacters <= 0) { + return null; // do nothing. + } + final char[] data = input.toCharArray(); + final int length = data.length; + + int repeatedCharacters = 1; // number of consecutive identical characters observed + int sourceIndex = 0; + int destIndex = 0; + + // visit each position of the source string, tracking runs of consecutive characters in + // 'repeatedCharacters'. + for (; sourceIndex < length; sourceIndex++) { + if (sourceIndex < (length - 1) && data[sourceIndex] == data[sourceIndex + 1]) { + repeatedCharacters++; + } else { + repeatedCharacters = 1; // reset consecutive character counter. + } + + // if we see more than maxRepeatedCharacters consecutive characters, we will + // skip them. Otherwise, we leave the data alone. If we have skipped characters, + // we need to copy the subsequent characters to a new position. + if (repeatedCharacters <= maxRepeatedCharacters) { + if (destIndex < sourceIndex) { + data[destIndex] = data[sourceIndex]; + } + destIndex++; + } + } + + // if we have modified the data, create and return a new string; otherwise return null. + if (destIndex < length) { + return new String(data, 0, destIndex); + } else { + return null; + } + } + + public static SSDeepHash parseAndNormalize(String ssdeepHash, int maxRepeatedChars) throws SSDeepParseException { + final String[] parts = ssdeepHash.split(CHUNK_DELIMITER); // possible NPE. + if (parts.length < 2 || parts.length > 3) { + throw new SSDeepParseException("Could not parse SSDeepHash, expected 2 or 3 '" + CHUNK_DELIMITER + "'-delimited segments observed " + parts.length, + ssdeepHash); + } + + try { + final int chunkSize = Integer.parseInt(parts[0]); + final String chunk = parts[1]; + final String doubleChunk = (parts.length == 3) ? parts[2] : ""; + + final String chunkNorm = normalizeSSDeepChunk(chunk, maxRepeatedChars); + final String doubleChunkNorm = normalizeSSDeepChunk(doubleChunk, maxRepeatedChars); + + // @formatter: off + return new SSDeepHash(chunkSize, chunkNorm == null ? chunk : chunkNorm, doubleChunkNorm == null ?
doubleChunk : doubleChunkNorm); + // @formatter: on + + } catch (NumberFormatException nfe) { + throw new SSDeepParseException("Could not parse SSDeepHash, expected first segment to be an integer", ssdeepHash); + } + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepHashParser.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepHashParser.java new file mode 100644 index 00000000000..0d6402b17f7 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepHashParser.java @@ -0,0 +1,21 @@ +package datawave.query.util.ssdeep; + +import java.util.Collections; +import java.util.Iterator; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SSDeepHashParser { + + private static final Logger log = LoggerFactory.getLogger(SSDeepHashParser.class); + + public Iterator call(String s) throws Exception { + try { + return Collections.singletonList(SSDeepHash.parse(s)).iterator(); + } catch (Exception e) { + log.info("Could not parse SSDeepHash: '" + s + "'"); + } + return Collections.emptyIterator(); + } +} diff --git a/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepParseException.java b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepParseException.java new file mode 100644 index 00000000000..e883ba385a6 --- /dev/null +++ b/warehouse/query-core/src/main/java/datawave/query/util/ssdeep/SSDeepParseException.java @@ -0,0 +1,23 @@ +package datawave.query.util.ssdeep; + +import java.nio.charset.StandardCharsets; + +public class SSDeepParseException extends RuntimeException { + final String message; + final String input; + + public SSDeepParseException(String message, byte[] input) { + this.message = message; + this.input = new String(input, StandardCharsets.UTF_8); + } + + public SSDeepParseException(String message, String input) { + this.message = message; + this.input = input; + } + + @Override + public String toString() { + return "SSDeepParseException{" + "message='" + message + '\'' + ", input='" + input + '\'' + '}'; + } +} diff --git a/warehouse/query-core/src/test/java/datawave/query/SSDeepQueryTest.java b/warehouse/query-core/src/test/java/datawave/query/SSDeepQueryTest.java new file mode 100644 index 00000000000..bb932abe2f8 --- /dev/null +++ b/warehouse/query-core/src/test/java/datawave/query/SSDeepQueryTest.java @@ -0,0 +1,205 @@ +package datawave.query; + +import static org.junit.Assert.fail; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Stream; + +import org.apache.accumulo.core.client.AccumuloClient; +import org.apache.accumulo.core.client.BatchWriter; +import org.apache.accumulo.core.client.BatchWriterConfig; +import org.apache.accumulo.core.client.Scanner; +import org.apache.accumulo.core.client.TableNotFoundException; +import org.apache.accumulo.core.client.admin.TableOperations; +import org.apache.accumulo.core.data.Key; +import org.apache.accumulo.core.data.Mutation; +import org.apache.accumulo.core.data.Value; +import org.apache.accumulo.core.security.Authorizations; +import org.apache.accumulo.core.security.ColumnVisibility; +import org.apache.commons.collections4.iterators.TransformIterator; +import org.apache.log4j.Logger; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import com.google.common.collect.Sets; + +import 
datawave.accumulo.inmemory.InMemoryAccumuloClient; +import datawave.accumulo.inmemory.InMemoryInstance; +import datawave.marking.MarkingFunctions; +import datawave.microservice.querymetric.QueryMetricFactoryImpl; +import datawave.query.tables.SSDeepSimilarityQueryLogic; +import datawave.query.testframework.AbstractDataTypeConfig; +import datawave.query.transformer.SSDeepSimilarityQueryTransformer; +import datawave.query.util.Tuple2; +import datawave.query.util.ssdeep.BucketAccumuloKeyGenerator; +import datawave.query.util.ssdeep.NGramByteHashGenerator; +import datawave.query.util.ssdeep.NGramTuple; +import datawave.security.authorization.DatawavePrincipal; +import datawave.security.authorization.DatawaveUser; +import datawave.security.authorization.SubjectIssuerDNPair; +import datawave.webservice.common.connection.AccumuloConnectionFactory; +import datawave.webservice.query.QueryImpl; +import datawave.webservice.query.result.event.DefaultResponseObjectFactory; +import datawave.webservice.query.result.event.EventBase; +import datawave.webservice.query.result.event.FieldBase; +import datawave.webservice.query.runner.RunningQuery; +import datawave.webservice.result.EventQueryResponseBase; + +public class SSDeepQueryTest { + + public static String[] TEST_SSDEEPS = {"12288:002r/VG4GjeZHkwuPikQ7lKH5p5H9x1beZHkwulizQ1lK55pGxlXTd8zbW:002LVG4GjeZEXi37l6Br1beZEdic1lmu", + "6144:02C3nq73v1kHGhs6y7ppFj93NRW6/ftZTgC6e8o4toHZmk6ZxoXb0ns:02C4cGCLjj9Swfj9koHEk6/Fns", + "3072:02irbxzGAFYDMxud7fKg3dXVmbOn5u46Kjnz/G8VYrs123D6pIJLIOSP:02MKlWQ7Sg3d4bO968rm7JO", + "48:1aBhsiUw69/UXX0x0qzNkVkydf2klA8a7Z35:155w69MXAlNkmkWTF5", "196608:wEEE+EEEEE0LEEEEEEEEEEREEEEhEEETEEEEEWUEEEJEEEEcEEEEEEEE3EEEEEEN:", + "1536:0YgNvw/OmgPgiQeI+25Nh6+RS5Qa8LmbyfAiIRgizy1cBx76UKYbD+iD/RYgNvw6:", "12288:222222222222222222222222222222222:"}; + + private static final Logger log = Logger.getLogger(SSDeepQueryTest.class); + + private static final Authorizations auths = AbstractDataTypeConfig.getTestAuths(); + + protected static AccumuloClient accumuloClient; + + protected SSDeepSimilarityQueryLogic logic; + + protected DatawavePrincipal principal; + + public static final int BUCKET_COUNT = BucketAccumuloKeyGenerator.DEFAULT_BUCKET_COUNT; + public static final int BUCKET_ENCODING_BASE = BucketAccumuloKeyGenerator.DEFAULT_BUCKET_ENCODING_BASE; + public static final int BUCKET_ENCODING_LENGTH = BucketAccumuloKeyGenerator.DEFAULT_BUCKET_ENCODING_LENGTH; + + public static void indexSSDeepTestData(AccumuloClient accumuloClient) throws Exception { + // configuration + String ssdeepTableName = "ssdeepIndex"; + int ngramSize = 7; + int minHashSize = 3; + + // input + Stream ssdeepLines = Stream.of(TEST_SSDEEPS); + + // processing + final NGramByteHashGenerator nGramGenerator = new NGramByteHashGenerator(ngramSize, BUCKET_COUNT, minHashSize); + final BucketAccumuloKeyGenerator accumuloKeyGenerator = new BucketAccumuloKeyGenerator(BUCKET_COUNT, BUCKET_ENCODING_BASE, BUCKET_ENCODING_LENGTH); + + // output + BatchWriterConfig batchWriterConfig = new BatchWriterConfig(); + final BatchWriter bw = accumuloClient.createBatchWriter(ssdeepTableName, batchWriterConfig); + + // operations + ssdeepLines.forEach(s -> { + try { + Iterator> it = nGramGenerator.call(s); + while (it.hasNext()) { + Tuple2 nt = it.next(); + Tuple2 at = accumuloKeyGenerator.call(nt); + Key k = at.first(); + Mutation m = new Mutation(k.getRow()); + ColumnVisibility cv = new ColumnVisibility(k.getColumnVisibility()); + m.put(k.getColumnFamily(), k.getColumnQualifier(), cv, 
k.getTimestamp(), at.second()); + bw.addMutation(m); + } + bw.flush(); + } catch (Exception e) { + log.error("Exception loading ssdeep hashes", e); + fail("Exception while loading ssdeep hashes: " + e.getMessage()); + } + }); + + bw.close(); + } + + @BeforeClass + public static void loadData() throws Exception { + final String tableName = "ssdeepIndex"; + + InMemoryInstance i = new InMemoryInstance("ssdeepTestInstance"); + accumuloClient = new InMemoryAccumuloClient("root", i); + + /* create the table */ + TableOperations tops = accumuloClient.tableOperations(); + if (tops.exists(tableName)) { + tops.delete(tableName); + } + tops.create(tableName); + + /* add ssdeep data to the table */ + indexSSDeepTestData(accumuloClient); + + /* dump the table */ + logSSDeepTestData(tableName); + } + + private static void logSSDeepTestData(String tableName) throws TableNotFoundException { + Scanner scanner = accumuloClient.createScanner(tableName, auths); + Iterator> iterator = scanner.iterator(); + log.debug("*************** " + tableName + " ********************"); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + log.debug(entry); + } + scanner.close(); + } + + @Before + public void setUpQuery() { + logic = new SSDeepSimilarityQueryLogic(); + logic.setTableName("ssdeepIndex"); + logic.setMarkingFunctions(new MarkingFunctions.Default()); + logic.setResponseObjectFactory(new DefaultResponseObjectFactory()); + logic.setBucketEncodingBase(BUCKET_ENCODING_BASE); + logic.setBucketEncodingLength(BUCKET_ENCODING_LENGTH); + logic.setIndexBuckets(BUCKET_COUNT); + + SubjectIssuerDNPair dn = SubjectIssuerDNPair.of("userDn", "issuerDn"); + DatawaveUser user = new DatawaveUser(dn, DatawaveUser.UserType.USER, Sets.newHashSet(auths.toString().split(",")), null, null, -1L); + principal = new DatawavePrincipal(Collections.singleton(user)); + } + + @Test + public void testSingleQuery() throws Exception { + String query = "CHECKSUM_SSDEEP:" + TEST_SSDEEPS[2]; + EventQueryResponseBase response = runSSDeepQuery(query); + List events = response.getEvents(); + int eventCount = events.size(); + + Map observedFields = new HashMap<>(); + if (eventCount > 0) { + for (EventBase e : events) { + List fields = e.getFields(); + for (FieldBase f : fields) { + observedFields.put(f.getName(), f.getValueString()); + } + } + } + + Assert.assertFalse("Observed fields was unexpectedly empty", observedFields.isEmpty()); + Assert.assertEquals("65.0", observedFields.remove("MATCH_SCORE")); + Assert.assertEquals("1", observedFields.remove("MATCH_RANK")); + Assert.assertEquals("3072:02irbxzGAFYDMxud7fKg3dXVmbOn5u46Kjnz/G8VYrs123D6pIJLIOSP:02MKlWQ7Sg3d4bO968rm7JO", observedFields.remove("QUERY_SSDEEP")); + Assert.assertEquals("3072:02irbxzGAFYDMxud7fKg3dXVmbOn5u46Kjnz/G8VYrs123D6pIJLIOSP:02MKlWQ7Sg3d4bO968rm7JO", observedFields.remove("MATCHING_SSDEEP")); + Assert.assertTrue("Observed unexpected field(s): " + observedFields, observedFields.isEmpty()); + Assert.assertEquals(1, eventCount); + } + + public EventQueryResponseBase runSSDeepQuery(String query) throws Exception { + QueryImpl q = new QueryImpl(); + q.setQuery(query); + q.setId(UUID.randomUUID()); + q.setPagesize(Integer.MAX_VALUE); + q.setQueryAuthorizations(auths.toString()); + + RunningQuery runner = new RunningQuery(accumuloClient, AccumuloConnectionFactory.Priority.NORMAL, this.logic, q, "", principal, + new QueryMetricFactoryImpl()); + TransformIterator transformIterator = runner.getTransformIterator(); + SSDeepSimilarityQueryTransformer transformer = 
(SSDeepSimilarityQueryTransformer) transformIterator.getTransformer(); + EventQueryResponseBase response = (EventQueryResponseBase) transformer.createResponse(runner.next()); + + return response; + } +} diff --git a/warehouse/query-core/src/test/java/datawave/query/transformer/SSDeepSimilarityQueryTransformerTest.java b/warehouse/query-core/src/test/java/datawave/query/transformer/SSDeepSimilarityQueryTransformerTest.java new file mode 100644 index 00000000000..4cf1042f585 --- /dev/null +++ b/warehouse/query-core/src/test/java/datawave/query/transformer/SSDeepSimilarityQueryTransformerTest.java @@ -0,0 +1,96 @@ +package datawave.query.transformer; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.accumulo.core.data.Key; +import org.apache.accumulo.core.data.Value; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.modules.junit4.PowerMockRunner; + +import com.google.common.collect.Multimap; +import com.google.common.collect.TreeMultimap; + +import datawave.marking.MarkingFunctions; +import datawave.query.config.SSDeepSimilarityQueryConfiguration; +import datawave.query.tables.SSDeepSimilarityQueryLogic; +import datawave.query.util.ssdeep.ChunkSizeEncoding; +import datawave.query.util.ssdeep.IntegerEncoding; +import datawave.query.util.ssdeep.NGramTuple; +import datawave.query.util.ssdeep.SSDeepHash; +import datawave.webservice.query.Query; +import datawave.webservice.query.result.event.DefaultEvent; +import datawave.webservice.query.result.event.DefaultField; +import datawave.webservice.query.result.event.ResponseObjectFactory; +import datawave.webservice.result.BaseQueryResponse; +import datawave.webservice.result.DefaultEventQueryResponse; + +@RunWith(PowerMockRunner.class) +@PowerMockIgnore({"javax.management.*", "javax.xml.*"}) +public class SSDeepSimilarityQueryTransformerTest { + @Mock + private Query mockQuery; + + @Mock + private MarkingFunctions mockMarkingFunctions; + + @Mock + private ResponseObjectFactory mockResponseFactory; + + String chunk = "//thPkK"; + int chunkSize = 3; + String ssdeepString = "3:yionv//thPkKlDtn/rXScG2/uDlhl2UE9FQEul/lldDpZflsup:6v/lhPkKlDtt/6TIPFQEqRDpZ+up"; + + public void basicExpects(Key k) { + EasyMock.expect(mockQuery.getQueryAuthorizations()).andReturn("A,B,C"); + EasyMock.expect(mockResponseFactory.getEventQueryResponse()).andReturn(new DefaultEventQueryResponse()); + EasyMock.expect(mockResponseFactory.getEvent()).andReturn(new DefaultEvent()).times(1); + EasyMock.expect(mockResponseFactory.getField()).andReturn(new DefaultField()).times(4); + } + + @Test + public void transformTest() { + int bucketEncodingBase = 32; + int bucketEncodingLength = 2; + + NGramTuple tuple = new NGramTuple(chunkSize, chunk); + SSDeepHash hash = SSDeepHash.parse(ssdeepString); + + Multimap queryMap = TreeMultimap.create(); + queryMap.put(tuple, hash); + + Key key = new Key("+++//thPkK", "3", "3:yionv//thPkKlDtn/rXScG2/uDlhl2UE9FQEul/lldDpZflsup:6v/lhPkKlDtt/6TIPFQEqRDpZ+up"); + Value value = new Value(); + AbstractMap.SimpleEntry entry = new AbstractMap.SimpleEntry<>(key, value); + + SSDeepSimilarityQueryConfiguration config = SSDeepSimilarityQueryConfiguration.create(); + 
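// The bucket encoding base/length set below are assumed to match the values used when the ngram
// index rows were written, so the transformer can decode the bucket prefix that leads each row.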
config.setBucketEncodingBase(bucketEncodingBase); + config.setBucketEncodingLength(bucketEncodingLength); + config.setQueryMap(queryMap); + + basicExpects(key); + + PowerMock.replayAll(); + + SSDeepSimilarityQueryTransformer transformer = new SSDeepSimilarityQueryTransformer(mockQuery, config, mockMarkingFunctions, mockResponseFactory); + Map.Entry transformedTuple = transformer.transform(entry); + List resultList = new ArrayList<>(); + resultList.add(transformedTuple); + BaseQueryResponse response = transformer.createResponse(resultList); + + PowerMock.verifyAll(); + + Assert.assertNotNull(transformedTuple); + Assert.assertEquals(hash, transformedTuple.getKey()); + Assert.assertEquals(tuple, transformedTuple.getValue()); + } +} diff --git a/web-services/deploy/configuration/src/main/resources/datawave/query/SSDeepQueryLogicFactory.xml b/web-services/deploy/configuration/src/main/resources/datawave/query/SSDeepQueryLogicFactory.xml new file mode 100644 index 00000000000..628b5bddee3 --- /dev/null +++ b/web-services/deploy/configuration/src/main/resources/datawave/query/SSDeepQueryLogicFactory.xml @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + +
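For illustration, a minimal usage sketch of the normalization behavior implemented by SSDeepHash above; the class name SSDeepNormalizeExample and the hash value are hypothetical, with the chunk size 96 chosen to satisfy the 3 * 2^n constraint enforced by the SSDeepHash constructor.

import datawave.query.util.ssdeep.SSDeepHash;

public class SSDeepNormalizeExample {
    public static void main(String[] args) {
        // Hypothetical hash: the first chunk contains a run of seven 'a' characters.
        String raw = "96:aaaaaaabcdefg:abcdefg";

        // Collapse any run longer than DEFAULT_MAX_REPEATED_CHARACTERS (3) in either chunk.
        SSDeepHash normalized = SSDeepHash.parseAndNormalize(raw, SSDeepHash.DEFAULT_MAX_REPEATED_CHARACTERS);

        // Prints 96:aaabcdefg:abcdefg -- the run of 'a' characters is collapsed to three by
        // normalizeSSDeepChunk; the double chunk has no long runs, so it is left unchanged.
        System.out.println(normalized);
    }
}

SSDeepHash.normalize(int) applies the same collapsing to an already-parsed hash.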