Add zero-shot classification w/ bart-large-mnli (huggingface#56)
xenova committed Mar 27, 2023
1 parent 4e96788 commit 230984e
Showing 4 changed files with 146 additions and 1 deletion.
6 changes: 5 additions & 1 deletion scripts/tasks.py
@@ -32,7 +32,11 @@
],
'facebook/bart-large-cnn': [
'seq2seq-lm-with-past'
]
],
'facebook/bart-large-mnli': [
'default',
'sequence-classification',
],
},
'bert': {
'bert-base-uncased': [
9 changes: 9 additions & 0 deletions src/models.js
@@ -879,6 +879,13 @@ class BartForConditionalGeneration extends BartPretrainedModel {
}
}

class BartForSequenceClassification extends BartPretrainedModel {
async _call(model_inputs) {
let logits = (await super._call(model_inputs)).logits;
return new SequenceClassifierOutput(logits)
}
}

//////////////////////////////////////////////////

//////////////////////////////////////////////////
@@ -1275,6 +1282,8 @@ class AutoModelForSequenceClassification {
return new DistilBertForSequenceClassification(config, session);
case 'roberta':
return new RobertaForSequenceClassification(config, session);
case 'bart':
return new BartForSequenceClassification(config, session);

default:
throw Error(`Unsupported model type: ${config.model_type}`)
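For context, a minimal sketch of how the new 'bart' branch above gets exercised; the from_pretrained and tokenizer calls are assumed from the surrounding code in models.js and pipelines.js, not part of this diff:

let tokenizer = await AutoTokenizer.from_pretrained('facebook/bart-large-mnli');
let model = await AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli');
// config.model_type is 'bart', so the switch above constructs a BartForSequenceClassification.

let inputs = tokenizer('one day I will see the world', { text_pair: 'This example is travel.' });
let outputs = await model(inputs);
// SequenceClassifierOutput; for bart-large-mnli, outputs.logits is ordered
// [contradiction, neutral, entailment] (see the id2label TODO in pipelines.js below).
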
77 changes: 77 additions & 0 deletions src/pipelines.js
@@ -288,6 +288,73 @@ class TextGenerationPipeline extends Pipeline {
}
}

class ZeroShotClassificationPipeline extends Pipeline {
async _call(texts, candidate_labels, {
hypothesis_template = "This example is {}.",
multi_label = false,
} = {}) {

let isBatched = Array.isArray(texts);

if (!isBatched) {
texts = [texts];
}
if (!Array.isArray(candidate_labels)) {
candidate_labels = [candidate_labels];
}

// Insert labels into hypothesis template
let hypotheses = candidate_labels.map(
x => hypothesis_template.replace('{}', x)
);

// How to perform the softmax over the logits:
// - true: softmax over the entailment vs. contradiction dim for each label independently
// - false: softmax the "entailment" logits over all candidate labels
let softmaxEach = multi_label || candidate_labels.length === 1;

let toReturn = [];
for (let premise of texts) {
let entails_logits = [];

for (let hypothesis of hypotheses) {
let inputs = this.tokenizer(premise, {
text_pair: hypothesis,
})
let outputs = await this.model(inputs)

// TODO do not assume (2) is entailment. Better to use model.id2label
if (softmaxEach) {
entails_logits.push([outputs.logits.data[0], outputs.logits.data[2]])
} else {
entails_logits.push(outputs.logits.data[2])
}
}
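// Worked note: with multi_label (or a single candidate label), each [contradiction, entailment]
// pair is softmaxed independently and the entailment probability kept, so scores need not sum
// to 1 (see the multi_label expectations in tests/index.js). Otherwise, the entailment logits
// are softmaxed across all candidate labels and the scores sum to 1.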

let scores;
if (softmaxEach) {
scores = entails_logits.map(x => softmax(x)[1]);
} else {
scores = softmax(entails_logits);
}

// Sort by scores (desc) and return scores with indices
let scores_sorted = scores
.map((x, i) => [x, i])
.sort((a, b) => {
return b[0] - a[0];
});

toReturn.push({
sequence: premise,
labels: scores_sorted.map(x => candidate_labels[x[1]]),
scores: scores_sorted.map(x => x[0]),
});
}
return isBatched ? toReturn : toReturn[0];
}
}
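
For reference, a usage sketch of the pipeline above. The inputs mirror the tests added in tests/index.js further down; the pipeline() factory and its task registration appear later in this file:

let classifier = await pipeline('zero-shot-classification', 'facebook/bart-large-mnli');
let result = await classifier(
    'one day I will see the world',
    ['travel', 'cooking', 'dancing'],
    { multi_label: true }  // score each label independently; hypothesis_template defaults to "This example is {}."
);
// result is { sequence, labels, scores }, with labels sorted by descending score.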


class EmbeddingsPipeline extends Pipeline {
// Should only be used with sentence-transformers
@@ -723,6 +790,15 @@ const SUPPORTED_TASKS = {
},
"type": "text",
},
"zero-shot-classification": {
"tokenizer": AutoTokenizer,
"pipeline": ZeroShotClassificationPipeline,
"model": AutoModelForSequenceClassification,
"default": {
"model": "facebook/bart-large-mnli",
},
"type": "text",
},

"automatic-speech-recognition": {
"tokenizer": AutoTokenizer,
@@ -807,6 +883,7 @@ const TASK_NAME_MAPPING = {
'image-to-text': 'vision2seq-lm-with-past',

'zero-shot-image-classification': 'default',
'zero-shot-classification': 'sequence-classification'
}

const TASK_PREFIX_MAPPING = {
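Taken together, the SUPPORTED_TASKS entry and the TASK_NAME_MAPPING line above let the task be requested without naming a model. A minimal sketch, assuming pipeline() falls back to the registered default as it does for the other tasks:

let classifier = await pipeline('zero-shot-classification');  // resolves to facebook/bart-large-mnli
let output = await classifier('I love making pizza', ['travel', 'cooking', 'dancing']);
// With multi_label left at its default (false), output.scores forms a distribution over the labels.
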
55 changes: 55 additions & 0 deletions tests/index.js
@@ -139,6 +139,60 @@ async function text_classification() {
), duration];
}


async function zero_shot_classification() {
let classifier = await pipeline('zero-shot-classification', 'facebook/bart-large-mnli');


let sequences_to_classify = ['one day I will see the world', 'I love making pizza'];
let candidate_labels = ['travel', 'cooking', 'dancing'];

let start = performance.now();

let outputs1 = await classifier(sequences_to_classify[0], candidate_labels);
let outputs2 = await classifier(sequences_to_classify, candidate_labels);
let outputs3 = await classifier(sequences_to_classify, candidate_labels, {
multi_label: true
})

let duration = performance.now() - start;

// Dispose pipeline
await classifier.dispose()

return [isDeepEqual(
outputs1,
{
sequence: "one day I will see the world",
labels: ["travel", "dancing", "cooking"],
scores: [0.4261703487477968, 0.2903585771517135, 0.28347107410048983]
}
) && isDeepEqual(
outputs2,
[{
sequence: "one day I will see the world",
labels: ["travel", "dancing", "cooking"],
scores: [0.4261703487477968, 0.2903585771517135, 0.28347107410048983]
}, {
sequence: "I love making pizza",
labels: ["cooking", "travel", "dancing"],
scores: [0.4660367922118968, 0.2756005926506238, 0.2583626151374795]
}]
) && isDeepEqual(
outputs3,
[{
sequence: "one day I will see the world",
labels: ["travel", "dancing", "cooking"],
scores: [0.7108286792234982, 0.5763787804099745, 0.44303326070949994]
}, {
sequence: "I love making pizza",
labels: ["cooking", "travel", "dancing"],
scores: [0.8527619536354446, 0.7899589317978243, 0.5838912691496106]
}]
), duration];

}

async function masked_language_modelling() {

let unmasker = await pipeline('fill-mask', 'bert-base-uncased');
@@ -774,6 +828,7 @@ console.warn = (...data) => {
// Define tests
let tests = {
'Text classification:': text_classification,
'Zero-shot classification:': zero_shot_classification,
'Masked language modelling:': masked_language_modelling,
'Question answering:': question_answering,
'Summarization:': summarization,
