azalea_core/
socket.rs

/// Errors produced by the Unix-socket helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A message could not be read from the socket or decoded.
    Read,
    /// A message could not be sent over the socket.
    Write,
    /// Binding or connecting the socket failed; carries the OS error text.
    UnixSocket(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Read => f.write_str("failed to read message from unix socket"),
            Self::Write => f.write_str("failed to write message to unix socket"),
            Self::UnixSocket(reason) => write!(f, "unix socket error: {reason}"),
        }
    }
}

impl std::error::Error for Error {}
7
8pub mod sync {
9    use crate::log;
10
11    use super::Error;
12    use std::{
13        io::{Read, Write},
14        os::unix::net::{UnixListener, UnixStream},
15    };
16
17    pub struct UnixListenerWrapper {
18        listener: UnixListener,
19    }
20
21    impl UnixListenerWrapper {
22        pub fn bind<P>(path: P) -> Result<Self, Error>
23        where
24            P: AsRef<std::path::Path>,
25        {
26            drop(std::fs::remove_file(&path));
27            match UnixListener::bind(path) {
28                Ok(listener) => Ok(Self { listener }),
29                Err(e) => Err(Error::UnixSocket(e.to_string())),
30            }
31        }
32
33        pub fn loop_accept<F>(&self, mut callback: F) -> Result<(), Error>
34        where
35            F: FnMut(UnixStreamWrapper) -> Result<bool, Error>,
36        {
37            loop {
38                match self.listener.accept() {
39                    Err(e) => log::warning!("failed to connect {e:?}"),
40                    Ok((stream, _addr)) => {
41                        let stream = UnixStreamWrapper::new(stream);
42                        match callback(stream) {
43                            Ok(alive) => {
44                                if alive {
45                                    continue;
46                                }
47                                return Ok(());
48                            }
49                            Err(e) => log::warning!("failed to execute callback {e:?}"),
50                        }
51                    }
52                }
53            }
54        }
55    }
56
57    pub struct UnixStreamWrapper {
58        stream: UnixStream,
59    }
60
61    impl UnixStreamWrapper {
62        pub fn new(stream: UnixStream) -> Self {
63            Self { stream }
64        }
65
66        pub fn connect<P>(path: P) -> Result<Self, Error>
67        where
68            P: AsRef<std::path::Path>,
69        {
70            match std::os::unix::net::UnixStream::connect(path) {
71                Ok(stream) => Ok(UnixStreamWrapper::new(stream)),
72                Err(e) => Err(Error::UnixSocket(e.to_string())),
73            }
74        }
75
76        pub fn read<T>(&mut self) -> Result<T, Error>
77        where
78            T: serde::de::DeserializeOwned,
79        {
80            let mut response = vec![];
81            drop(self.stream.read_to_end(&mut response));
82            match serde_json::from_slice(&response) {
83                Ok(response) => Ok(response),
84                Err(_) => Err(Error::Read),
85            }
86        }
87
88        pub fn write<E>(&mut self, payload: E) -> Result<(), Error>
89        where
90            E: serde::Serialize,
91        {
92            let ans = match self
93                .stream
94                .write_all(&serde_json::to_vec(&payload).unwrap())
95            {
96                Ok(_) => Ok(()),
97                Err(_) => Err(Error::Write),
98            };
99            drop(self.stream.shutdown(std::net::Shutdown::Write));
100            ans
101        }
102    }
103}
104
105pub mod r#async {
106    use crate::log;
107
108    use super::Error;
109    use futures_lite::io::{AsyncReadExt, AsyncWriteExt};
110
111    use async_net::unix::{UnixListener, UnixStream};
112
113    pub struct UnixListenerWrapper {
114        listener: UnixListener,
115    }
116
117    impl UnixListenerWrapper {
118        pub fn bind<P>(path: P) -> Result<Self, Error>
119        where
120            P: AsRef<std::path::Path>,
121        {
122            drop(std::fs::remove_file(&path));
123            match UnixListener::bind(path) {
124                Ok(listener) => Ok(Self { listener }),
125                Err(e) => Err(Error::UnixSocket(e.to_string())),
126            }
127        }
128
129        pub async fn loop_accept<F>(&self, mut callback: F)
130        where
131            F: AsyncFnMut(UnixStreamWrapper) -> bool,
132        {
133            loop {
134                match self.listener.accept().await {
135                    Err(e) => log::warning!("failed to connect {e:?}"),
136                    Ok((stream, _addr)) => {
137                        let stream = UnixStreamWrapper::new(stream);
138                        let alive = callback(stream).await;
139                        if !alive {
140                            return;
141                        }
142                    }
143                }
144            }
145        }
146    }
147
148    pub struct UnixStreamWrapper {
149        stream: UnixStream,
150    }
151
152    impl UnixStreamWrapper {
153        pub fn new(stream: UnixStream) -> Self {
154            Self { stream }
155        }
156
157        pub async fn read<T>(&mut self) -> Result<T, Error>
158        where
159            T: serde::de::DeserializeOwned,
160        {
161            let mut response = vec![];
162            drop(self.stream.read_to_end(&mut response).await);
163            match serde_json::from_slice(&response) {
164                Ok(response) => Ok(response),
165                Err(_) => Err(Error::Read),
166            }
167        }
168
169        pub async fn write<E>(&mut self, payload: E) -> Result<(), Error>
170        where
171            E: serde::Serialize,
172        {
173            let ans = match self
174                .stream
175                .write_all(&serde_json::to_vec(&payload).unwrap())
176                .await
177            {
178                Ok(_) => Ok(()),
179                Err(_) => Err(Error::Write),
180            };
181            drop(self.stream.shutdown(std::net::Shutdown::Write));
182            ans
183        }
184    }
185}