Generify login

This commit is contained in:
Aram 🍐 2021-07-31 23:55:12 -04:00
parent db8158eea7
commit 77d3020df5
3 changed files with 122 additions and 74 deletions

View file

@ -16,9 +16,28 @@ pub enum Connection {
impl Connection { impl Connection {
/// Initializes a connection to a NUT server (upsd). /// Initializes a connection to a NUT server (upsd).
pub fn new(config: &Config) -> crate::Result<Self> { pub fn new(config: &Config) -> crate::Result<Self> {
match &config.host { let mut conn = match &config.host {
Host::Tcp(host) => Ok(Self::Tcp(TcpConnection::new(config.clone(), &host.addr)?)), Host::Tcp(host) => Self::Tcp(TcpConnection::new(config.clone(), &host.addr)?),
};
conn.get_network_version()?;
conn.login(config)?;
Ok(conn)
}
/// Sends username and password, as applicable.
fn login(&mut self, config: &Config) -> crate::Result<()> {
if let Some(auth) = config.auth.clone() {
// Pass username and check for 'OK'
self.set_username(&auth.username)?;
// Pass password and check for 'OK'
if let Some(password) = &auth.password {
self.set_password(password)?;
}
} }
Ok(())
} }
} }
@ -36,25 +55,18 @@ impl TcpConnection {
config, config,
stream: ConnectionStream::Plain(tcp_stream), stream: ConnectionStream::Plain(tcp_stream),
}; };
// Initialize SSL connection
connection = connection.enable_ssl()?; connection = connection.enable_ssl()?;
// Attempt login using `config.auth`
connection.login()?;
Ok(connection) Ok(connection)
} }
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
fn enable_ssl(mut self) -> crate::Result<Self> { fn enable_ssl(mut self) -> crate::Result<Self> {
if self.config.ssl { if self.config.ssl {
// Send TLS request and check for 'OK'
self.write_cmd(Command::StartTLS)?; self.write_cmd(Command::StartTLS)?;
self.read_response() self.read_response()
.map_err(|e| { .map_err(|e| {
if let ClientError::Nut(NutError::FeatureNotConfigured) = e { if let crate::ClientError::Nut(NutError::FeatureNotConfigured) = e {
ClientError::Nut(NutError::SslNotSupported) crate::ClientError::Nut(NutError::SslNotSupported)
} else { } else {
e e
} }
@ -92,10 +104,6 @@ impl TcpConnection {
// Wrap and override the TCP stream // Wrap and override the TCP stream
self.stream = self.stream.upgrade_ssl(sess)?; self.stream = self.stream.upgrade_ssl(sess)?;
// Send a test command
self.write_cmd(Command::NetworkVersion)?;
self.read_plain_response()?;
} }
Ok(self) Ok(self)
} }
@ -105,21 +113,6 @@ impl TcpConnection {
Ok(self) Ok(self)
} }
fn login(&mut self) -> crate::Result<()> {
if let Some(auth) = self.config.auth.clone() {
// Pass username and check for 'OK'
self.write_cmd(Command::SetUsername(&auth.username))?;
self.read_response()?.expect_ok()?;
// Pass password and check for 'OK'
if let Some(password) = &auth.password {
self.write_cmd(Command::SetPassword(password))?;
self.read_response()?.expect_ok()?;
}
}
Ok(())
}
pub(crate) fn write_cmd(&mut self, line: Command) -> crate::Result<()> { pub(crate) fn write_cmd(&mut self, line: Command) -> crate::Result<()> {
let line = format!("{}\n", line); let line = format!("{}\n", line);
if self.config.debug { if self.config.debug {

View file

@ -248,7 +248,7 @@ macro_rules! implement_list_commands {
( (
$( $(
$(#[$attr:meta])+ $(#[$attr:meta])+
fn $name:ident($($argname:ident: $argty:ty),*) -> $retty:ty { $vis:vis fn $name:ident($($argname:ident: $argty:ty),*) -> $retty:ty {
( (
$query:block, $query:block,
$mapper:block, $mapper:block,
@ -259,7 +259,8 @@ macro_rules! implement_list_commands {
impl crate::blocking::Connection { impl crate::blocking::Connection {
$( $(
$(#[$attr])* $(#[$attr])*
pub fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> { #[allow(dead_code)]
$vis fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> {
match self { match self {
Self::Tcp(conn) => { Self::Tcp(conn) => {
conn.write_cmd(Command::List($query))?; conn.write_cmd(Command::List($query))?;
@ -275,7 +276,8 @@ macro_rules! implement_list_commands {
impl crate::tokio::Connection { impl crate::tokio::Connection {
$( $(
$(#[$attr])* $(#[$attr])*
pub async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> { #[allow(dead_code)]
$vis async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> {
match self { match self {
Self::Tcp(conn) => { Self::Tcp(conn) => {
conn.write_cmd(Command::List($query)).await?; conn.write_cmd(Command::List($query)).await?;
@ -298,7 +300,7 @@ macro_rules! implement_get_commands {
( (
$( $(
$(#[$attr:meta])+ $(#[$attr:meta])+
fn $name:ident($($argname:ident: $argty:ty),*) -> $retty:ty { $vis:vis fn $name:ident($($argname:ident: $argty:ty),*) -> $retty:ty {
( (
$query:block, $query:block,
$mapper:block, $mapper:block,
@ -309,7 +311,8 @@ macro_rules! implement_get_commands {
impl crate::blocking::Connection { impl crate::blocking::Connection {
$( $(
$(#[$attr])* $(#[$attr])*
pub fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> { #[allow(dead_code)]
$vis fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> {
match self { match self {
Self::Tcp(conn) => { Self::Tcp(conn) => {
conn.write_cmd(Command::Get($query))?; conn.write_cmd(Command::Get($query))?;
@ -324,7 +327,8 @@ macro_rules! implement_get_commands {
impl crate::tokio::Connection { impl crate::tokio::Connection {
$( $(
$(#[$attr])* $(#[$attr])*
pub async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> { #[allow(dead_code)]
$vis async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> {
match self { match self {
Self::Tcp(conn) => { Self::Tcp(conn) => {
conn.write_cmd(Command::Get($query)).await?; conn.write_cmd(Command::Get($query)).await?;
@ -346,7 +350,7 @@ macro_rules! implement_simple_commands {
( (
$( $(
$(#[$attr:meta])+ $(#[$attr:meta])+
fn $name:ident($($argname:ident: $argty:ty),*) -> $retty:ty { $vis:vis fn $name:ident($($argname:ident: $argty:ty),*) -> $retty:ty {
( (
$cmd:block, $cmd:block,
$mapper:block, $mapper:block,
@ -357,7 +361,8 @@ macro_rules! implement_simple_commands {
impl crate::blocking::Connection { impl crate::blocking::Connection {
$( $(
$(#[$attr])* $(#[$attr])*
pub fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> { #[allow(dead_code)]
$vis fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> {
match self { match self {
Self::Tcp(conn) => { Self::Tcp(conn) => {
conn.write_cmd($cmd)?; conn.write_cmd($cmd)?;
@ -372,7 +377,8 @@ macro_rules! implement_simple_commands {
impl crate::tokio::Connection { impl crate::tokio::Connection {
$( $(
$(#[$attr])* $(#[$attr])*
pub async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> { #[allow(dead_code)]
$vis async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<$retty> {
match self { match self {
Self::Tcp(conn) => { Self::Tcp(conn) => {
conn.write_cmd($cmd).await?; conn.write_cmd($cmd).await?;
@ -385,9 +391,54 @@ macro_rules! implement_simple_commands {
}; };
} }
/// A macro for implementing action commands that return `OK`.
///
/// Each function should return the command to pass.
macro_rules! implement_action_commands {
(
$(
$(#[$attr:meta])+
$vis:vis fn $name:ident($($argname:ident: $argty:ty),*) $cmd:block
)*
) => {
impl crate::blocking::Connection {
$(
$(#[$attr])*
#[allow(dead_code)]
$vis fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<()> {
match self {
Self::Tcp(conn) => {
conn.write_cmd($cmd)?;
conn.read_response()?.expect_ok()?;
Ok(())
},
}
}
)*
}
#[cfg(feature = "async")]
impl crate::tokio::Connection {
$(
$(#[$attr])*
#[allow(dead_code)]
$vis async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<()> {
match self {
Self::Tcp(conn) => {
conn.write_cmd($cmd).await?;
conn.read_response().await?.expect_ok()?;
Ok(())
},
}
}
)*
}
};
}
implement_list_commands! { implement_list_commands! {
/// Queries a list of UPS devices. /// Queries a list of UPS devices.
fn list_ups() -> Vec<(String, String)> { pub fn list_ups() -> Vec<(String, String)> {
( (
{ &["UPS"] }, { &["UPS"] },
{ |row: Response| row.expect_ups() }, { |row: Response| row.expect_ups() },
@ -395,7 +446,7 @@ implement_list_commands! {
} }
/// Queries a list of client IP addresses connected to the given device. /// Queries a list of client IP addresses connected to the given device.
fn list_clients(ups_name: &str) -> Vec<String> { pub fn list_clients(ups_name: &str) -> Vec<String> {
( (
{ &["CLIENT", ups_name] }, { &["CLIENT", ups_name] },
{ |row: Response| row.expect_client() }, { |row: Response| row.expect_client() },
@ -403,7 +454,7 @@ implement_list_commands! {
} }
/// Queries the list of variables for a UPS device. /// Queries the list of variables for a UPS device.
fn list_vars(ups_name: &str) -> Vec<Variable> { pub fn list_vars(ups_name: &str) -> Vec<Variable> {
( (
{ &["VAR", ups_name] }, { &["VAR", ups_name] },
{ |row: Response| row.expect_var() }, { |row: Response| row.expect_var() },
@ -413,7 +464,7 @@ implement_list_commands! {
implement_get_commands! { implement_get_commands! {
/// Queries one variable for a UPS device. /// Queries one variable for a UPS device.
fn get_var(ups_name: &str, variable: &str) -> Variable { pub fn get_var(ups_name: &str, variable: &str) -> Variable {
( (
{ &["VAR", ups_name, variable] }, { &["VAR", ups_name, variable] },
{ |row: Response| row.expect_var() }, { |row: Response| row.expect_var() },
@ -423,10 +474,22 @@ implement_get_commands! {
implement_simple_commands! { implement_simple_commands! {
/// Queries the network protocol version. /// Queries the network protocol version.
fn get_network_version() -> String { pub fn get_network_version() -> String {
( (
{ Command::NetworkVersion }, { Command::NetworkVersion },
{ |row: String| Ok(row) }, { |row: String| Ok(row) },
) )
} }
} }
implement_action_commands! {
/// Sends the login username.
pub(crate) fn set_username(username: &str) {
Command::SetUsername(username)
}
/// Sends the login password.
pub(crate) fn set_password(password: &str) {
Command::SetPassword(password)
}
}

View file

@ -17,11 +17,28 @@ pub enum Connection {
impl Connection { impl Connection {
/// Initializes a connection to a NUT server (upsd). /// Initializes a connection to a NUT server (upsd).
pub async fn new(config: &Config) -> crate::Result<Self> { pub async fn new(config: &Config) -> crate::Result<Self> {
match &config.host { let mut conn = match &config.host {
Host::Tcp(host) => Ok(Self::Tcp( Host::Tcp(host) => Self::Tcp(TcpConnection::new(config.clone(), &host.addr).await?),
TcpConnection::new(config.clone(), &host.addr).await?, };
)),
conn.get_network_version().await?;
conn.login(config).await?;
Ok(conn)
}
/// Sends username and password, as applicable.
async fn login(&mut self, config: &Config) -> crate::Result<()> {
if let Some(auth) = config.auth.clone() {
// Pass username and check for 'OK'
self.set_username(&auth.username).await?;
// Pass password and check for 'OK'
if let Some(password) = &auth.password {
self.set_password(password).await?;
}
} }
Ok(())
} }
} }
@ -39,13 +56,7 @@ impl TcpConnection {
config, config,
stream: ConnectionStream::Plain(tcp_stream), stream: ConnectionStream::Plain(tcp_stream),
}; };
// Initialize SSL connection
connection = connection.enable_ssl().await?; connection = connection.enable_ssl().await?;
// Attempt login using `config.auth`
connection.login().await?;
Ok(connection) Ok(connection)
} }
@ -99,10 +110,6 @@ impl TcpConnection {
// Wrap and override the TCP stream // Wrap and override the TCP stream
self.stream = self.stream.upgrade_ssl(config, dns_name.as_ref()).await?; self.stream = self.stream.upgrade_ssl(config, dns_name.as_ref()).await?;
// Send a test command
self.write_cmd(Command::NetworkVersion).await?;
self.read_plain_response().await?;
} }
Ok(self) Ok(self)
} }
@ -112,21 +119,6 @@ impl TcpConnection {
Ok(self) Ok(self)
} }
async fn login(&mut self) -> crate::Result<()> {
if let Some(auth) = self.config.auth.clone() {
// Pass username and check for 'OK'
self.write_cmd(Command::SetUsername(&auth.username)).await?;
self.read_response().await?.expect_ok()?;
// Pass password and check for 'OK'
if let Some(password) = &auth.password {
self.write_cmd(Command::SetPassword(password)).await?;
self.read_response().await?.expect_ok()?;
}
}
Ok(())
}
pub(crate) async fn write_cmd(&mut self, line: Command<'_>) -> crate::Result<()> { pub(crate) async fn write_cmd(&mut self, line: Command<'_>) -> crate::Result<()> {
let line = format!("{}\n", line); let line = format!("{}\n", line);
if self.config.debug { if self.config.debug {