Skip to content

Commit

Permalink
Iterable Dataloader Improvements (#91)
Browse files Browse the repository at this point in the history
* Iterable Dataloader

* Iterable Dataloader Improvements
  • Loading branch information
ibanesh authored Nov 20, 2023
1 parent 4f3f3f3 commit ec759d1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
2 changes: 2 additions & 0 deletions simuleval/data/dataloader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def add_args(parser: ArgumentParser):


class IterableDataloader:
cur_index: int

@abstractmethod
def __iter__(self):
...
Expand Down
34 changes: 16 additions & 18 deletions simuleval/evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,25 @@
# LICENSE file in the root directory of this source tree.

import contextlib
import pandas
import os
import json
import logging
import numbers
import os
from argparse import Namespace
from pathlib import Path
from typing import Dict, Generator, Optional

import pandas
import yaml
from simuleval.data.dataloader import GenericDataloader, build_dataloader
from simuleval.data.dataloader.dataloader import IterableDataloader
from tqdm import tqdm

from .instance import INSTANCE_TYPE_DICT, LogInstance
from .scorers import get_scorer_class
from .scorers.latency_scorer import LatencyScorer
from .scorers.quality_scorer import QualityScorer

from .instance import INSTANCE_TYPE_DICT, LogInstance
import yaml
import logging
import json
from tqdm import tqdm
from pathlib import Path
from simuleval.data.dataloader import GenericDataloader, build_dataloader

try:
import sentencepiece

Expand Down Expand Up @@ -170,10 +170,9 @@ def build_instances_from_log(self):
with open(self.output / "instances.log", "r") as f:
for line in f:
instance = LogInstance(line.strip())
self.instances[instance.index] = instance
self.instances[instance.index].set_target_spm_model(
self.target_spm_model
)
index = instance.index - self.start_index
self.instances[index] = instance
self.instances[index].set_target_spm_model(self.target_spm_model)

def build_instances_from_dataloader(self):
if isinstance(self.dataloader, IterableDataloader):
Expand Down Expand Up @@ -245,11 +244,12 @@ def __call__(self, system):
with open(
self.output / "instances.log", "a"
) if self.output else contextlib.nullcontext() as file:
idx = 0
system.reset()
for sample in self.iterator:
instance = (
self.instance_class(idx, self.dataloader, self.args)
self.instance_class(
self.dataloader.cur_index, self.dataloader, self.args
)
if isinstance(self.dataloader, IterableDataloader)
else sample
)
Expand All @@ -268,8 +268,6 @@ def __call__(self, system):
if not self.score_only and self.output:
file.write(json.dumps(instance.summarize()) + "\n")

idx += 1

if self.output:
self.build_instances_from_log()
self.dump_results()
Expand Down

0 comments on commit ec759d1

Please sign in to comment.