Skip to content

Commit

Permalink
Replace magic number 4 with NUM_STATES. Add comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
awong-dev committed Apr 14, 2024
1 parent 46e4273 commit ae1d702
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 38 deletions.
6 changes: 3 additions & 3 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ fn main() {
let mut lines = reader.lines().map(|x| x.unwrap()).skip_while(|x| x.starts_with('#'));
let prob_start = lines.next().unwrap();
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
write!(&mut file, "static INITIAL_PROBS: StatusSet = [").unwrap();
write!(&mut file, "static INITIAL_PROBS: StateSet = [").unwrap();
for prob in prob_start.split(' ') {
write!(&mut file, "{}, ", prob).unwrap();
}
write!(&mut file, "];\n\n").unwrap();
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
write!(&mut file, "static TRANS_PROBS: [StatusSet; 4] = [").unwrap();
write!(&mut file, "static TRANS_PROBS: [StateSet; crate::hmm::NUM_STATES] = [").unwrap();
for line in lines
.by_ref()
.skip_while(|x| x.starts_with('#'))
Expand Down Expand Up @@ -50,5 +50,5 @@ fn main() {
i += 1;
}
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
writeln!(&mut file, "static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; 4] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap();
writeln!(&mut file, "static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; crate::hmm::NUM_STATES] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap();
}
87 changes: 54 additions & 33 deletions src/hmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,67 @@ lazy_static! {
static ref RE_SKIP: Regex = Regex::new(r"([a-zA-Z0-9]+(?:.\d+)?%?)").unwrap();
}

pub type StatusSet = [f64; 4];

pub const NUM_STATES: usize = 4;

pub type StateSet = [f64; NUM_STATES];

/// Result of hmm is a labeling of each Unicode Scalar Value in the input
/// string with Begin, Middle, End, or Single. These denote the proposed
/// segments. A segment is one of the following two patterns.
///
/// Begin, [Middle, Middle, ...], End
/// Single
///
/// Each state in the enum is also assigned an index value from 0-3 that
/// can be used as an index into an array representing data pertaining
/// to that state.
///
/// WARNING: the data file format for hmm.model comments imply one can
/// reassign the index values of each state at the top but `build.rs`
/// currently ignores the mapping. Do not reassign these indicies without
/// verifying hot it interacts with `build.rs`
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)]
pub enum Status {
B = 0,
E = 1,
M = 2,
S = 3,
pub enum State {
Begin = 0,
End = 1,
Middle = 2,
Single = 3,
}

static PREV_STATUS: [[Status; 2]; 4] = [
[Status::E, Status::S], // B
[Status::B, Status::M], // E
[Status::M, Status::B], // M
[Status::S, Status::E], // S
];
// Mapping representing the allow transitiongs into the given state.
static ALLOWED_PREV_STATUS: [[State; 2]; NUM_STATES] = {
let mut valid_transitions_from: [[State; 2]; NUM_STATES] = [[State::Begin, State::Begin]; NUM_STATES];
valid_transitions_from[State::Begin as usize] = [State::End, State::Single];
valid_transitions_from[State::End as usize] = [State::Begin, State::Middle];
valid_transitions_from[State::Middle as usize] = [State::Middle, State::Begin];
valid_transitions_from[State::Single as usize] = [State::Single, State::End];
valid_transitions_from
};

include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));

const MIN_FLOAT: f64 = -3.14e100;

pub(crate) struct HmmContext {
v: Vec<f64>,
prev: Vec<Option<Status>>,
best_path: Vec<Status>,
prev: Vec<Option<State>>,
best_path: Vec<State>,
}

impl HmmContext {
pub fn new(num_states: usize, num_characters: usize) -> Self {
pub fn new(num_characters: usize) -> Self {
HmmContext {
v: vec![0.0; num_states * num_characters],
prev: vec![None; num_states * num_characters],
best_path: vec![Status::B; num_characters],
v: vec![0.0; NUM_STATES * num_characters],
prev: vec![None; NUM_STATES * num_characters],
best_path: vec![State::Begin; num_characters],
}
}
}

#[allow(non_snake_case)]
fn viterbi(sentence: &str, hmm_context: &mut HmmContext) {
let str_len = sentence.len();
let states = [Status::B, Status::M, Status::E, Status::S];
let states = [State::Begin, State::Middle, State::End, State::Single];
#[allow(non_snake_case)]
let R = states.len();
let C = sentence.chars().count();
Expand All @@ -66,7 +86,7 @@ fn viterbi(sentence: &str, hmm_context: &mut HmmContext) {
}

if hmm_context.best_path.len() < C {
hmm_context.best_path.resize(C, Status::B);
hmm_context.best_path.resize(C, State::Begin);
}

let mut curr = sentence.char_indices().map(|x| x.0).peekable();
Expand All @@ -84,7 +104,7 @@ fn viterbi(sentence: &str, hmm_context: &mut HmmContext) {
let byte_end = *curr.peek().unwrap_or(&str_len);
let word = &sentence[byte_start..byte_end];
let em_prob = EMIT_PROBS[*y as usize].get(word).cloned().unwrap_or(MIN_FLOAT);
let (prob, state) = PREV_STATUS[*y as usize]
let (prob, state) = ALLOWED_PREV_STATUS[*y as usize]
.iter()
.map(|y0| {
(
Expand All @@ -104,7 +124,7 @@ fn viterbi(sentence: &str, hmm_context: &mut HmmContext) {
t += 1;
}

let (_prob, state) = [Status::E, Status::S]
let (_prob, state) = [State::End, State::Single]
.iter()
.map(|y| (hmm_context.v[(C - 1) * R + (*y as usize)], y))
.max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
Expand Down Expand Up @@ -137,20 +157,20 @@ pub fn cut_internal<'a>(sentence: &'a str, words: &mut Vec<&'a str>, hmm_context
while let Some(curr_byte_offset) = curr.next() {
let state = hmm_context.best_path[i];
match state {
Status::B => begin = curr_byte_offset,
Status::E => {
State::Begin => begin = curr_byte_offset,
State::End => {
let byte_start = begin;
let byte_end = *curr.peek().unwrap_or(&str_len);
words.push(&sentence[byte_start..byte_end]);
next_byte_offset = byte_end;
}
Status::S => {
State::Single => {
let byte_start = curr_byte_offset;
let byte_end = *curr.peek().unwrap_or(&str_len);
words.push(&sentence[byte_start..byte_end]);
next_byte_offset = byte_end;
}
Status::M => { /* do nothing */ }
State::Middle => { /* do nothing */ }
}

i += 1;
Expand Down Expand Up @@ -193,8 +213,7 @@ pub(crate) fn cut_with_allocated_memory<'a>(sentence: &'a str, words: &mut Vec<&

#[allow(non_snake_case)]
pub fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) {
// TODO: Is 4 just the number of variants in Status?
let mut hmm_context = HmmContext::new(4, sentence.chars().count());
let mut hmm_context = HmmContext::new(sentence.chars().count());

cut_with_allocated_memory(sentence, words, &mut hmm_context)
}
Expand All @@ -206,14 +225,16 @@ mod tests {
#[test]
#[allow(non_snake_case)]
fn test_viterbi() {
use super::Status::*;
use super::State::*;

let sentence = "小明硕士毕业于中国科学院计算所";

// TODO: Is 4 just the number of variants in Status?
let mut hmm_context = HmmContext::new(4, sentence.chars().count());
let mut hmm_context = HmmContext::new(sentence.chars().count());
viterbi(sentence, &mut hmm_context);
assert_eq!(hmm_context.best_path, vec![B, E, B, E, B, M, E, B, E, B, M, E, B, E, S]);
assert_eq!(
hmm_context.best_path,
vec![Begin, End, Begin, End, Begin, Middle, End, Begin, End, Begin, Middle, End, Begin, End, Single]
);
}

#[test]
Expand Down
3 changes: 1 addition & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,7 @@ impl Jieba {
let mut route = Vec::with_capacity(heuristic_capacity);
let mut dag = StaticSparseDAG::with_size_hint(heuristic_capacity);

// TODO: Is 4 just the number of variants in Status?
let mut hmm_context = hmm::HmmContext::new(4, sentence.chars().count());
let mut hmm_context = hmm::HmmContext::new(sentence.chars().count());

for state in splitter {
match state {
Expand Down

0 comments on commit ae1d702

Please sign in to comment.