Skip to content

Commit

Permalink
Move request id generation and checking to separate wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
dvolodin7 committed Jan 8, 2024
1 parent ad2dfac commit ea66150
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 23 deletions.
36 changes: 13 additions & 23 deletions src/socket/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// See LICENSE.md for details
// ------------------------------------------------------------------------

use super::RequestId;
use crate::ber::{BerEncoder, SnmpOid, ToPython};
use crate::buf::Buffer;
use crate::error::SnmpError;
Expand All @@ -21,7 +22,6 @@ use pyo3::{
prelude::*,
types::{PyDict, PyList, PyTuple},
};
use rand::Rng;
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use std::net::SocketAddr;
use std::os::fd::AsRawFd;
Expand All @@ -33,7 +33,7 @@ pub struct SnmpClientSocket {
addr: SockAddr,
community: String,
version: SnmpVersion,
request_id: i64,
request_id: RequestId,
buf: Buffer,
}

Expand Down Expand Up @@ -99,7 +99,7 @@ impl SnmpClientSocket {
addr: sock_addr.into(),
community,
version,
request_id: 0,
request_id: RequestId::default(),
buf: Buffer::default(),
})
}
Expand Down Expand Up @@ -129,14 +129,12 @@ impl SnmpClientSocket {
fn send_getnext(&mut self, iter: &GetNextIter) -> PyResult<()> {
// Start from clear buffer
self.buf.reset();
// Get new request id
let request_id = self.new_request_id();
// Encode message
let msg = SnmpMessage {
version: self.version.clone(),
community: self.community.as_ref(),
pdu: SnmpPdu::GetNextRequest(SnmpGet {
request_id,
request_id: self.request_id.next(),
vars: vec![iter.get_next_oid()],
}),
};
Expand All @@ -152,14 +150,12 @@ impl SnmpClientSocket {
fn send_getbulk(&mut self, iter: &GetBulkIter) -> PyResult<()> {
// Start from clear buffer
self.buf.reset();
// Get new request id
let request_id = self.new_request_id();
// Encode message
let msg = SnmpMessage {
version: self.version.clone(),
community: self.community.as_ref(),
pdu: SnmpPdu::GetBulkRequest(SnmpGetBulk {
request_id,
request_id: self.request_id.next(),
non_repeaters: 0,
max_repetitions: iter.get_max_repetitions(),
vars: vec![iter.get_next_oid()],
Expand Down Expand Up @@ -197,7 +193,7 @@ impl SnmpClientSocket {
match msg.pdu {
SnmpPdu::GetResponse(resp) => {
// Check request id
if resp.request_id != self.request_id {
if !self.request_id.check(resp.request_id) {
continue; // Not our request
}
// Check error_index
Expand Down Expand Up @@ -250,7 +246,7 @@ impl SnmpClientSocket {
match msg.pdu {
SnmpPdu::GetResponse(resp) => {
// Check request id
if resp.request_id != self.request_id {
if !self.request_id.check(resp.request_id) {
continue; // Not our request
}
// Check error_index
Expand Down Expand Up @@ -301,7 +297,7 @@ impl SnmpClientSocket {
match msg.pdu {
SnmpPdu::GetResponse(resp) => {
// Check request id
if resp.request_id != self.request_id {
if !self.request_id.check(resp.request_id) {
continue; // Not our request
}
// Check error_index
Expand Down Expand Up @@ -355,7 +351,7 @@ impl SnmpClientSocket {
match msg.pdu {
SnmpPdu::GetResponse(resp) => {
// Check request id
if resp.request_id != self.request_id {
if !self.request_id.check(resp.request_id) {
continue; // Not our request
}
// Check error_index
Expand Down Expand Up @@ -420,24 +416,18 @@ impl SnmpClientSocket {
}
Err(PyOSError::new_err("unable to set buffer size"))
}
//
fn new_request_id(&mut self) -> i64 {
let mut rng = rand::thread_rng();
let x: i64 = rng.gen();
self.request_id = x & 0x7fffffff;
self.request_id
}
/// Send GET request
fn _send_get(&mut self, vars: Vec<SnmpOid>) -> PyResult<()> {
// Start from clear buffer
self.buf.reset();
// Get new request id
let request_id = self.new_request_id();
// Encode message
let msg = SnmpMessage {
version: self.version.clone(),
community: self.community.as_ref(),
pdu: SnmpPdu::GetRequest(SnmpGet { request_id, vars }),
pdu: SnmpPdu::GetRequest(SnmpGet {
request_id: self.request_id.next(),
vars,
}),
};
msg.push_ber(&mut self.buf)
.map_err(|_| PyValueError::new_err("failed to encode message"))?;
Expand Down
2 changes: 2 additions & 0 deletions src/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@
// ------------------------------------------------------------------------

pub mod client;
pub mod reqid;
pub use client::{GetBulkIter, GetNextIter, SnmpClientSocket};
use reqid::RequestId;
51 changes: 51 additions & 0 deletions src/socket/reqid.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// ------------------------------------------------------------------------
// Gufo SNMP: Id Generator
// ------------------------------------------------------------------------
// Copyright (C) 2023-24, Gufo Labs
// See LICENSE.md for details
// ------------------------------------------------------------------------

use rand::Rng;

#[derive(Default)]
pub struct RequestId(i64);

impl RequestId {
/// Get next value
pub fn next(&mut self) -> i64 {
let mut rng = rand::thread_rng();
let x: i64 = rng.gen();
self.0 = x & 0x7fffffff;
self.0
}
/// Check values for match
pub fn check(&self, v: i64) -> bool {
self.0 == v
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_default() {
let r = RequestId::default();
assert!(r.check(0))
}

#[test]
fn test_check() {
let mut r = RequestId::default();
let v1 = r.next();
assert!(r.check(v1))
}

#[test]
fn test_seq() {
let mut r = RequestId::default();
let v1 = r.next();
let v2 = r.next();
assert!(v1 != v2)
}
}

0 comments on commit ea66150

Please sign in to comment.