Skip to content

Commit

Permalink
Merge pull request #72 from Garvys/task/rm_epsilon_revamp
Browse files Browse the repository at this point in the history
RmEpsilon revamp : Dynamic version
  • Loading branch information
Alexandre Caulier authored Feb 3, 2020
2 parents cb414d8 + 21e3699 commit a535ba8
Show file tree
Hide file tree
Showing 32 changed files with 742 additions and 359 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `union` -> `UnionFst`
- `concat` -> `ConcatFst`
- `closure` -> `ClosureFst`
- `rmepsilon` -> `RmEpsilonFst`
- Added `delete_final_weight_unchecked` to the `Fst` trait and implement it for `VectorFst`.
- Added `SerializableSemiring` trait and implement it for most `Semiring`s.
- Every `Fst` that implements `SerializableFst` with a `Semiring` implementing `SerializableSemiring` can now be serialized/deserialized consistently with OpenFst.
Expand Down Expand Up @@ -55,6 +56,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `MutableFst` now has a trait bound on `ExpandedFst`.
- `DrawingConfig` parameters `size`, `ranksep` and `nodesep` are now optional.
- Fix SymbolTable conservation for `Reverse` and `ShortestPath`.
- `RmEpsilon` now mutates its input.
- `dfs_visit` now accepts an `ArcFilter` to be able to skip some arcs.
- `AutoQueue` and `TopOrderQueue` now take an `ArcFilter` as input.
- Removed the `Clone` and `PartialEq` trait bounds on `Fst`. These bounds are still required to be an `ExpandedFst`.
- `rmepsilon` no longer requires the `Semiring` to be a `StarSemiring`.
- Revamped the RmEpsilon and ShortestDistance implementations to be closer to OpenFst's.

## [0.4.0] - 2019-11-12

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ fn main() -> Fallible<()> {
project(&mut fst, ProjectType::ProjectInput);

// - Remove epsilon transitions.
fst = rm_epsilon(&fst)?;
rm_epsilon(&mut fst)?;

// - Compute an equivalent FST but deterministic.
fst = determinize(&fst, DeterminizeType::DeterminizeFunctional)?;
Expand Down
12 changes: 9 additions & 3 deletions rustfst-tests-data/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,16 @@ void compute_fst_reverse(const F& raw_fst, json& j) {

template<class F>
void compute_fst_remove_epsilon(const F& raw_fst, json& j) {
using Arc = typename F::Arc;
auto fst_out = *raw_fst.Copy();
// Connect = false
fst::RmEpsilon(&fst_out, false);
j["rmepsilon"]["result"] = fst_to_string(fst_out);

auto dyn_rmeps = fst::VectorFst<Arc>(fst::RmEpsilonFst<Arc>(raw_fst));

fst::RmEpsilon(&fst_out);
j["rmepsilon"]["result_static"] = fst_to_string(fst_out);
j["rmepsilon"]["result_dynamic"] = fst_to_string(dyn_rmeps);


}

template<class F>
Expand Down
6 changes: 5 additions & 1 deletion rustfst/src/algorithms/arc_filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ use crate::Arc;
use crate::EPS_LABEL;

/// Base trait to restrict which arcs are traversed in an FST.
pub trait ArcFilter<S: Semiring> {
pub trait ArcFilter<S: Semiring>: Clone {
/// If true, Arc should be kept, else Arc should be ignored.
fn keep(&self, arc: &Arc<S>) -> bool;
}

/// True for all arcs.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AnyArcFilter {}

impl<S: Semiring> ArcFilter<S> for AnyArcFilter {
Expand All @@ -18,6 +19,7 @@ impl<S: Semiring> ArcFilter<S> for AnyArcFilter {
}

/// True for (input/output) epsilon arcs.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EpsilonArcFilter {}

impl<S: Semiring> ArcFilter<S> for EpsilonArcFilter {
Expand All @@ -27,6 +29,7 @@ impl<S: Semiring> ArcFilter<S> for EpsilonArcFilter {
}

/// True for input epsilon arcs.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InputEpsilonArcFilter {}

impl<S: Semiring> ArcFilter<S> for InputEpsilonArcFilter {
Expand All @@ -36,6 +39,7 @@ impl<S: Semiring> ArcFilter<S> for InputEpsilonArcFilter {
}

/// True for output epsilon arcs.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OutputEpsilonArcFilter {}

impl<S: Semiring> ArcFilter<S> for OutputEpsilonArcFilter {
Expand Down
2 changes: 1 addition & 1 deletion rustfst/src/algorithms/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ where
}
}

#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, PartialEq)]
pub struct ClosureFst<F: Fst + 'static>(ReplaceFst<F, F>)
where
F::W: 'static;
Expand Down
3 changes: 2 additions & 1 deletion rustfst/src/algorithms/connect.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use failure::Fallible;
use unsafe_unwrap::UnsafeUnwrap;

use crate::algorithms::arc_filters::AnyArcFilter;
use crate::algorithms::dfs_visit::{dfs_visit, Visitor};
use crate::fst_traits::Fst;
use crate::fst_traits::{CoreFst, ExpandedFst, MutableFst};
Expand Down Expand Up @@ -46,7 +47,7 @@ use crate::NO_STATE_ID;
///
pub fn connect<F: ExpandedFst + MutableFst>(fst: &mut F) -> Fallible<()> {
let mut visitor = ConnectVisitor::new(fst);
dfs_visit(fst, &mut visitor, false);
dfs_visit(fst, &mut visitor, &AnyArcFilter {}, false);
let mut dstates = Vec::with_capacity(visitor.access.len());
for s in 0..visitor.access.len() {
if !visitor.access[s] || !visitor.coaccess[s] {
Expand Down
8 changes: 7 additions & 1 deletion rustfst/src/algorithms/dfs_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::fst_traits::{ArcIterator, ExpandedFst, Fst};
use crate::semirings::Semiring;
use crate::StateId;

use crate::algorithms::arc_filters::ArcFilter;
use unsafe_unwrap::UnsafeUnwrap;

#[derive(PartialOrd, PartialEq, Copy, Clone)]
Expand Down Expand Up @@ -90,9 +91,10 @@ impl<I: Iterator> OpenFstIterator<I> {
}
}

pub fn dfs_visit<'a, F: Fst + ExpandedFst, V: Visitor<'a, F>>(
pub fn dfs_visit<'a, F: Fst + ExpandedFst, V: Visitor<'a, F>, A: ArcFilter<F::W>>(
fst: &'a F,
visitor: &mut V,
arc_filter: &A,
access_only: bool,
) {
visitor.init_visit(fst);
Expand Down Expand Up @@ -137,6 +139,10 @@ pub fn dfs_visit<'a, F: Fst + ExpandedFst, V: Visitor<'a, F>>(
}
let arc = aiter.value();
let next_color = state_color[arc.nextstate];
if !(arc_filter.keep(arc)) {
aiter.next();
continue;
}
match next_color {
DfsStateColor::White => {
dfs = visitor.tree_arc(s, arc);
Expand Down
29 changes: 0 additions & 29 deletions rustfst/src/algorithms/dynamic_fst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,6 @@ macro_rules! dynamic_fst {
}
}

impl<$($a: $b $( < $c >)? ),*> PartialEq for $dyn_fst {
fn eq(&self, other: &Self) -> bool {
let ptr = self.fst_impl.get();
let fst_impl = unsafe { ptr.as_ref().unwrap() };

let ptr_other = other.fst_impl.get();
let fst_impl_other = unsafe { ptr_other.as_ref().unwrap() };

fst_impl.eq(fst_impl_other)
}
}

impl<$($a: $b $( < $c >)? ),*> std::fmt::Debug for $dyn_fst
where
$($d: $e,)?
Expand All @@ -44,23 +32,6 @@ macro_rules! dynamic_fst {
}
}

impl<$($a: $b $( < $c >)? ),*> Clone for $dyn_fst
where
$($d: $e,)?
F::W: 'static,
$($a : 'static),*
{
fn clone(&self) -> Self {
let ptr = self.fst_impl.get();
let fst_impl = unsafe { ptr.as_ref().unwrap() };
Self {
fst_impl: UnsafeCell::new(fst_impl.clone()),
isymt: self.input_symbols(),
osymt: self.output_symbols(),
}
}
}

impl<$($a: $b $( < $c >)? ),*> CoreFst for $dyn_fst
where
$($d: $e,)?
Expand Down
4 changes: 2 additions & 2 deletions rustfst/src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ pub use self::{
replace::{replace, BorrowFst, ReplaceFst},
reverse::reverse,
reweight::{reweight, ReweightType},
rm_epsilon::rm_epsilon,
rm_epsilon::{rm_epsilon, RmEpsilonFst},
rm_final_epsilon::rm_final_epsilon,
shortest_distance::{shortest_distance, single_source_shortest_distance},
shortest_distance::shortest_distance,
shortest_path::shortest_path,
state_sort::state_sort,
top_sort::top_sort,
Expand Down
2 changes: 1 addition & 1 deletion rustfst/src/algorithms/push.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn push_weights<F>(
) -> Fallible<()>
where
F: MutableFst,
F::W: WeaklyDivisibleSemiring,
F::W: WeaklyDivisibleSemiring + 'static,
<<F as CoreFst>::W as Semiring>::ReverseWeight: 'static,
{
let dist = shortest_distance(fst, reweight_type == ReweightType::ReweightToInitial)?;
Expand Down
22 changes: 18 additions & 4 deletions rustfst/src/algorithms/queues/auto_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ use super::{
natural_less, FifoQueue, LifoQueue, NaturalShortestFirstQueue, SccQueue, StateOrderQueue,
TopOrderQueue, TrivialQueue,
};
use crate::algorithms::arc_filters::ArcFilter;

/// Queue that delegates to a boxed `Queue` trait object.
///
/// The concrete queue discipline is chosen by `AutoQueue::new` from the
/// properties of the FST it is given (e.g. `StateOrderQueue`, `TopOrderQueue`
/// or `LifoQueue`).
#[derive(Debug)]
pub struct AutoQueue {
// Concrete queue implementation selected at construction time.
queue: Box<dyn Queue>,
}

impl AutoQueue {
pub fn new<F: MutableFst + ExpandedFst>(fst: &F, distance: Option<&Vec<F::W>>) -> Fallible<Self>
pub fn new<F: MutableFst + ExpandedFst, A: ArcFilter<F::W>>(
fst: &F,
distance: Option<&Vec<F::W>>,
arc_filter: &A,
) -> Fallible<Self>
where
F::W: 'static,
{
Expand All @@ -29,14 +34,14 @@ impl AutoQueue {
if props.contains(FstProperties::TOP_SORTED) || fst.start().is_none() {
queue = Box::new(StateOrderQueue::default());
} else if props.contains(FstProperties::ACYCLIC) {
queue = Box::new(TopOrderQueue::new(fst));
queue = Box::new(TopOrderQueue::new(fst, arc_filter));
} else if props.contains(FstProperties::UNWEIGHTED)
&& F::W::properties().contains(SemiringProperties::IDEMPOTENT)
{
queue = Box::new(LifoQueue::default());
} else {
let mut scc_visitor = SccVisitor::new(fst, true, false);
dfs_visit(fst, &mut scc_visitor, false);
dfs_visit(fst, &mut scc_visitor, arc_filter, false);
let sccs: Vec<_> = scc_visitor
.scc
.unwrap()
Expand Down Expand Up @@ -65,6 +70,7 @@ impl AutoQueue {
&mut queue_types,
&mut all_trivial,
&mut unweighted,
arc_filter,
)?;

if unweighted {
Expand Down Expand Up @@ -94,13 +100,18 @@ impl AutoQueue {
Ok(Self { queue })
}

pub fn scc_queue_type<F: MutableFst + ExpandedFst, C: Fn(&F::W, &F::W) -> Fallible<bool>>(
pub fn scc_queue_type<
F: MutableFst + ExpandedFst,
C: Fn(&F::W, &F::W) -> Fallible<bool>,
A: ArcFilter<F::W>,
>(
fst: &F,
sccs: &[usize],
compare: Option<C>,
queue_types: &mut Vec<QueueType>,
all_trivial: &mut bool,
unweighted: &mut bool,
arc_filter: &A,
) -> Fallible<()> {
*all_trivial = true;
*unweighted = true;
Expand All @@ -111,6 +122,9 @@ impl AutoQueue {

for state in 0..fst.num_states() {
for arc in unsafe { fst.arcs_iter_unchecked(state) } {
if !arc_filter.keep(arc) {
continue;
}
if sccs[state] == sccs[arc.nextstate] {
let queue_type = unsafe { queue_types.get_unchecked_mut(sccs[state]) };
if compare.is_none() || compare.as_ref().unwrap()(&arc.weight, &F::W::one())? {
Expand Down
2 changes: 1 addition & 1 deletion rustfst/src/algorithms/queues/fifo_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::algorithms::{Queue, QueueType};
use crate::StateId;

/// First-in, first-out (queue) queue discipline.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct FifoQueue(VecDeque<StateId>);

impl Queue for FifoQueue {
Expand Down
2 changes: 1 addition & 1 deletion rustfst/src/algorithms/queues/lifo_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::algorithms::{Queue, QueueType};
use crate::StateId;

/// Last-in, first-out (stack) queue discipline.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub struct LifoQueue(Vec<StateId>);

impl Queue for LifoQueue {
Expand Down
1 change: 1 addition & 0 deletions rustfst/src/algorithms/queues/shortest_first_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub fn natural_less<W: Semiring>(w1: &W, w2: &W) -> Fallible<bool> {
Ok((&w1.plus(w2)? == w1) && (w1 != w2))
}

/// Queue of `StateId`s stored in a `BinaryHeap` that is parameterized by a
/// comparator closure `C` taking two state ids and returning an `Ordering`.
#[derive(Clone)]
pub struct ShortestFirstQueue<C: Clone + FnMut(&StateId, &StateId) -> Ordering> {
// Heap of state ids; ordering comes from `FnComparator<C>`.
heap: BinaryHeap<StateId, FnComparator<C>>,
}
Expand Down
2 changes: 1 addition & 1 deletion rustfst/src/algorithms/queues/state_order_queue.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::algorithms::{Queue, QueueType};
use crate::StateId;

#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub struct StateOrderQueue {
front: StateId,
back: Option<StateId>,
Expand Down
7 changes: 4 additions & 3 deletions rustfst/src/algorithms/queues/top_order_queue.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::algorithms::arc_filters::ArcFilter;
use crate::algorithms::dfs_visit::dfs_visit;
use crate::algorithms::top_sort::TopOrderVisitor;
use crate::algorithms::{Queue, QueueType};
Expand All @@ -6,7 +7,7 @@ use crate::StateId;

/// Topological-order queue discipline, templated on the StateId. States are
/// ordered in the queue topologically. The FST must be acyclic.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct TopOrderQueue {
order: Vec<StateId>,
state: Vec<Option<StateId>>,
Expand All @@ -15,9 +16,9 @@ pub struct TopOrderQueue {
}

impl TopOrderQueue {
pub fn new<F: MutableFst + ExpandedFst>(fst: &F) -> Self {
pub fn new<F: MutableFst + ExpandedFst, A: ArcFilter<F::W>>(fst: &F, arc_filter: &A) -> Self {
let mut visitor = TopOrderVisitor::new();
dfs_visit(fst, &mut visitor, false);
dfs_visit(fst, &mut visitor, arc_filter, false);
if !visitor.acyclic {
panic!("Unexpectted Acyclic FST for TopOprerQueue");
}
Expand Down
2 changes: 1 addition & 1 deletion rustfst/src/algorithms/queues/trivial_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::StateId;
/// Trivial queue discipline; one may enqueue at most one state at a time. It
/// can be used for strongly connected components with only one state and no
/// self-loops.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub struct TrivialQueue {
state: Option<StateId>,
}
Expand Down
Loading

0 comments on commit a535ba8

Please sign in to comment.