// binaryninja/websocket/client.rs

1use crate::rc::{Ref, RefCountable};
2use crate::string::{BnString, IntoCStr};
3use binaryninjacore_sys::*;
4use std::ffi::{c_char, c_void, CStr};
5use std::marker::PhantomData;
6use std::ops::Deref;
7use std::ptr::NonNull;
8
/// Receives events from a websocket connection started with
/// [`Ref<CoreWebsocketClient>::connect`].
///
/// Callbacks are invoked on the connection's thread (managed by the websocket provider), not on
/// the thread that called `connect`. That is why implementors must be [`Sync`] + [`Send`]. All
/// methods take `&self`, so record any state through interior mutability (e.g. a `Mutex`).
pub trait WebsocketClientCallback: Sync + Send {
    /// Called once the websocket connection has been established successfully.
    ///
    /// Return `false` to terminate the connection early.
    fn connected(&self) -> bool;

    /// Called when the websocket connection has terminated.
    ///
    /// This is the last event of the connection lifecycle. It is delivered on normal termination
    /// and also shortly after [`WebsocketClientCallback::error`] has been reported.
    fn disconnected(&self);

    /// Called when the websocket connection reports an error.
    ///
    /// An error does not end the lifecycle on its own; [`WebsocketClientCallback::disconnected`]
    /// follows it.
    fn error(&self, msg: &str);

    /// Called with data received from the websocket connection.
    ///
    /// Implementations typically append `data` to an interior mutable buffer. Return `false` to
    /// abort the connection.
    fn read(&self, data: &[u8]) -> bool;
}
31
32pub trait WebsocketClient: Sync + Send {
33    /// Called to construct this client object with the given core object.
34    fn from_core(core: Ref<CoreWebsocketClient>) -> Self;
35
36    fn connect<I>(&self, host: &str, headers: I) -> bool
37    where
38        I: IntoIterator<Item = (String, String)>;
39
40    fn write(&self, data: &[u8]) -> bool;
41
42    fn disconnect(&self) -> bool;
43}
44
45/// Represents a live websocket connection.
46///
47/// This manages the lifetime of the callback, ensuring it outlives the connection.
48pub struct ActiveConnection<'a, C: WebsocketClientCallback> {
49    pub client: Ref<CoreWebsocketClient>,
50    _callback: PhantomData<&'a mut C>,
51}
52
53impl<'a, C: WebsocketClientCallback> Deref for ActiveConnection<'a, C> {
54    type Target = CoreWebsocketClient;
55
56    fn deref(&self) -> &Self::Target {
57        &self.client
58    }
59}
60
61impl<'a, C: WebsocketClientCallback> Drop for ActiveConnection<'a, C> {
62    fn drop(&mut self) {
63        self.client.disconnect();
64    }
65}
66
/// Implements a websocket client.
///
/// To connect, use [`Ref<CoreWebsocketClient>::connect`] which will return an [`ActiveConnection`]
/// which manages the lifecycle of the websocket connection.
///
/// This is a reference-counted wrapper around the core's `BNWebsocketClient` handle.
/// `#[repr(transparent)]` keeps its layout identical to the `NonNull` pointer it stores.
#[repr(transparent)]
pub struct CoreWebsocketClient {
    // Non-null pointer to the core client object. References are taken and released through
    // `BNNewWebsocketClientReference` / `BNFreeWebsocketClient` in the `RefCountable` impl.
    pub(crate) handle: NonNull<BNWebsocketClient>,
}
75
76impl CoreWebsocketClient {
77    pub(crate) unsafe fn ref_from_raw(handle: NonNull<BNWebsocketClient>) -> Ref<Self> {
78        Ref::new(Self { handle })
79    }
80
81    #[allow(clippy::mut_from_ref)]
82    pub(crate) unsafe fn as_raw(&self) -> &mut BNWebsocketClient {
83        &mut *self.handle.as_ptr()
84    }
85
86    /// Call the connect callback function, forward the callback returned value
87    pub fn notify_connected(&self) -> bool {
88        unsafe { BNNotifyWebsocketClientConnect(self.handle.as_ptr()) }
89    }
90
91    /// Notify the callback function of a disconnect. This must be called at the end of an active
92    /// websocket connection lifecycle, to free resources.
93    ///
94    /// NOTE: This does not actually disconnect, use the [Self::disconnect] function for that.
95    pub fn notify_disconnected(&self) {
96        unsafe { BNNotifyWebsocketClientDisconnect(self.handle.as_ptr()) }
97    }
98
99    /// Call the error callback function, this is not a terminating request you must use
100    /// [`CoreWebsocketClient::notify_disconnected`] to terminate the connection.
101    pub fn notify_error(&self, msg: &str) {
102        let error = msg.to_cstr();
103        unsafe { BNNotifyWebsocketClientError(self.handle.as_ptr(), error.as_ptr()) }
104    }
105
106    /// Call the read callback function, forward the callback returned value.
107    pub fn notify_read(&self, data: &[u8]) -> bool {
108        unsafe {
109            BNNotifyWebsocketClientReadData(
110                self.handle.as_ptr(),
111                data.as_ptr() as *mut _,
112                data.len().try_into().unwrap(),
113            )
114        }
115    }
116
117    pub fn write(&self, data: &[u8]) -> bool {
118        let len = u64::try_from(data.len()).unwrap();
119        unsafe { BNWriteWebsocketClientData(self.as_raw(), data.as_ptr(), len) != 0 }
120    }
121
122    pub fn disconnect(&self) -> bool {
123        unsafe { BNDisconnectWebsocketClient(self.as_raw()) }
124    }
125}
126
127impl Ref<CoreWebsocketClient> {
128    /// Initializes the web socket connection, returning the [`ActiveConnection`], once dropped the
129    /// connection will be disconnected.
130    ///
131    /// Connect to a given url, asynchronously. The connection will be run in a
132    /// separate thread managed by the websocket provider.
133    ///
134    /// Callbacks will be called **on the thread of the connection**, so be sure
135    /// to ExecuteOnMainThread any long-running or gui operations in the callbacks.
136    ///
137    /// If the connection succeeds, [WebsocketClientCallback::connected] will be called. On normal
138    /// termination, [WebsocketClientCallback::disconnected] will be called.
139    ///
140    /// If the connection succeeds but later fails, [`WebsocketClientCallback::error`] will be called
141    /// and shortly thereafter [`WebsocketClientCallback::disconnected`] will be called.
142    ///
143    /// If [`WebsocketClientCallback::connected`] or [`WebsocketClientCallback::read`] return false, the
144    /// connection will be aborted.
145    ///
146    /// * `host` - Full url with scheme, domain, optionally port, and path
147    /// * `headers` - HTTP header keys and values
148    /// * `callback` - Callbacks for various websocket events
149    #[must_use]
150    pub fn connect<'a, I, C>(
151        self,
152        host: &str,
153        headers: I,
154        callbacks: &'a C,
155    ) -> Option<ActiveConnection<'a, C>>
156    where
157        I: IntoIterator<Item = (String, String)>,
158        C: WebsocketClientCallback,
159    {
160        let url = host.to_cstr();
161        let (header_keys, header_values): (Vec<_>, Vec<_>) = headers
162            .into_iter()
163            .map(|(k, v)| (k.to_cstr(), v.to_cstr()))
164            .unzip();
165        let header_keys: Vec<*const c_char> = header_keys.iter().map(|k| k.as_ptr()).collect();
166        let header_values: Vec<*const c_char> = header_values.iter().map(|v| v.as_ptr()).collect();
167        // SAFETY: This context will live for as long as the `ActiveConnection` is alive.
168        // SAFETY: Any subsequent call to BNConnectWebsocketClient will write over the context.
169        let mut output_callbacks = BNWebsocketClientOutputCallbacks {
170            context: callbacks as *const C as *mut C as *mut c_void,
171            connectedCallback: Some(cb_connected::<C>),
172            disconnectedCallback: Some(cb_disconnected::<C>),
173            errorCallback: Some(cb_error::<C>),
174            readCallback: Some(cb_read::<C>),
175        };
176        let success = unsafe {
177            BNConnectWebsocketClient(
178                self.handle.as_ptr(),
179                url.as_ptr(),
180                header_keys.len().try_into().unwrap(),
181                header_keys.as_ptr(),
182                header_values.as_ptr(),
183                &mut output_callbacks,
184            )
185        };
186
187        if success {
188            Some(ActiveConnection {
189                client: self,
190                _callback: PhantomData,
191            })
192        } else {
193            None
194        }
195    }
196}
197
198unsafe impl Sync for CoreWebsocketClient {}
199unsafe impl Send for CoreWebsocketClient {}
200
201impl ToOwned for CoreWebsocketClient {
202    type Owned = Ref<Self>;
203
204    fn to_owned(&self) -> Self::Owned {
205        unsafe { RefCountable::inc_ref(self) }
206    }
207}
208
209unsafe impl RefCountable for CoreWebsocketClient {
210    unsafe fn inc_ref(handle: &Self) -> Ref<Self> {
211        let result = BNNewWebsocketClientReference(handle.as_raw());
212        unsafe { Self::ref_from_raw(NonNull::new(result).unwrap()) }
213    }
214
215    unsafe fn dec_ref(handle: &Self) {
216        BNFreeWebsocketClient(handle.as_raw())
217    }
218}
219
220pub(crate) unsafe extern "C" fn cb_destroy_client<W: WebsocketClient>(ctxt: *mut c_void) {
221    let _ = Box::from_raw(ctxt as *mut W);
222}
223
224pub(crate) unsafe extern "C" fn cb_connect<W: WebsocketClient>(
225    ctxt: *mut c_void,
226    host: *const c_char,
227    header_count: u64,
228    header_keys: *const *const c_char,
229    header_values: *const *const c_char,
230) -> bool {
231    let ctxt: &mut W = &mut *(ctxt as *mut W);
232    let host = CStr::from_ptr(host);
233    // SAFETY BnString and *mut c_char are transparent
234    let header_count = usize::try_from(header_count).unwrap();
235    let header_keys = core::slice::from_raw_parts(header_keys as *const BnString, header_count);
236    let header_values = core::slice::from_raw_parts(header_values as *const BnString, header_count);
237    let header_keys_str = header_keys.iter().map(|s| s.to_string_lossy().to_string());
238    let header_values_str = header_values
239        .iter()
240        .map(|s| s.to_string_lossy().to_string());
241    let header = header_keys_str.zip(header_values_str);
242    ctxt.connect(&host.to_string_lossy(), header)
243}
244
245pub(crate) unsafe extern "C" fn cb_write<W: WebsocketClient>(
246    data: *const u8,
247    len: u64,
248    ctxt: *mut c_void,
249) -> bool {
250    let ctxt: &mut W = &mut *(ctxt as *mut W);
251    let len = usize::try_from(len).unwrap();
252    let data = core::slice::from_raw_parts(data, len);
253    ctxt.write(data)
254}
255
256pub(crate) unsafe extern "C" fn cb_disconnect<W: WebsocketClient>(ctxt: *mut c_void) -> bool {
257    let ctxt: &mut W = &mut *(ctxt as *mut W);
258    ctxt.disconnect()
259}
260
261unsafe extern "C" fn cb_connected<W: WebsocketClientCallback>(ctxt: *mut c_void) -> bool {
262    let ctxt: &mut W = &mut *(ctxt as *mut W);
263    ctxt.connected()
264}
265
266unsafe extern "C" fn cb_disconnected<W: WebsocketClientCallback>(ctxt: *mut c_void) {
267    let ctxt: &mut W = &mut *(ctxt as *mut W);
268    ctxt.disconnected()
269}
270
271unsafe extern "C" fn cb_error<W: WebsocketClientCallback>(msg: *const c_char, ctxt: *mut c_void) {
272    let ctxt: &mut W = &mut *(ctxt as *mut W);
273    let msg = CStr::from_ptr(msg);
274    ctxt.error(&msg.to_string_lossy())
275}
276
277unsafe extern "C" fn cb_read<W: WebsocketClientCallback>(
278    data: *mut u8,
279    len: u64,
280    ctxt: *mut c_void,
281) -> bool {
282    let ctxt: &mut W = &mut *(ctxt as *mut W);
283    let len = usize::try_from(len).unwrap();
284    let data = core::slice::from_raw_parts_mut(data, len);
285    ctxt.read(data)
286}