1use std::{collections::HashMap, sync::Arc};
8
9use anyhow::Result;
10use bytes::Bytes;
11use serde::{Serialize, de::DeserializeOwned};
12use tokio::sync::{RwLock, mpsc, oneshot};
13use tracing::{debug, instrument, warn};
14#[allow(unused_imports)]
15use wasmtime::{Caller, Extern, Func, Linker, Store};
16
17#[derive(Debug, thiserror::Error)]
19pub enum BridgeError {
20 #[error("Function not found: {0}")]
22 FunctionNotFound(String),
23
24 #[error("Invalid function signature: {0}")]
26 InvalidSignature(String),
27
28 #[error("Serialization failed: {0}")]
30 SerializationError(String),
31
32 #[error("Deserialization failed: {0}")]
34 DeserializationError(String),
35
36 #[error("Host function error: {0}")]
38 HostFunctionError(String),
39
40 #[error("Communication timeout")]
42 Timeout,
43
44 #[error("Bridge closed")]
46 BridgeClosed,
47}
48
49pub type BridgeResult<T> = Result<T, BridgeError>;
51
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
54pub struct FunctionSignature {
55 pub name:String,
57 pub param_types:Vec<ParamType>,
59 pub return_type:Option<ReturnType>,
61 pub is_async:bool,
63}
64
65#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
67pub enum ParamType {
68 I32,
70 I64,
72 F32,
74 F64,
76 Ptr,
78 Len,
80}
81
82#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
84pub enum ReturnType {
85 I32,
87 I64,
89 F32,
91 F64,
93 Void,
95}
96
97#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
99pub struct HostMessage {
100 pub message_id:String,
102 pub function:String,
104 pub args:Vec<Bytes>,
106 pub callback_token:Option<u64>,
108}
109
110#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
112pub struct HostResponse {
113 pub message_id:String,
115 pub success:bool,
117 pub data:Option<Bytes>,
119 pub error:Option<String>,
121}
122
123#[derive(Clone)]
125pub struct AsyncCallback {
126 sender:Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<HostResponse>>>>,
128 message_id:String,
130}
131
132impl std::fmt::Debug for AsyncCallback {
133 fn fmt(&self, f:&mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 f.debug_struct("AsyncCallback").field("message_id", &self.message_id).finish()
135 }
136}
137
138impl AsyncCallback {
139 pub async fn send(self, response:HostResponse) -> Result<()> {
141 let mut sender_opt = self.sender.lock().await;
142 if let Some(sender) = sender_opt.take() {
143 sender.send(response).map_err(|_| BridgeError::BridgeClosed)?;
144 Ok(())
145 } else {
146 Err(BridgeError::BridgeClosed.into())
147 }
148 }
149}
150
151#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
153pub struct WASMMessage {
154 pub function:String,
156 pub args:Vec<Bytes>,
158}
159
160pub type HostFunctionCallback = fn(Vec<Bytes>) -> Result<Bytes>;
162
163pub type AsyncHostFunctionCallback =
165 fn(Vec<Bytes>) -> Box<dyn std::future::Future<Output = Result<Bytes>> + Send + Unpin>;
166
167#[derive(Debug)]
169pub struct HostFunction {
170 pub name:String,
172 pub signature:FunctionSignature,
174 #[allow(dead_code)]
176 pub callback:Option<HostFunctionCallback>,
177 #[allow(dead_code)]
179 pub async_callback:Option<AsyncHostFunctionCallback>,
180}
181
182#[derive(Debug)]
184pub struct HostBridgeImpl {
185 host_functions:Arc<RwLock<HashMap<String, HostFunction>>>,
187 wasm_to_host_rx:mpsc::UnboundedReceiver<WASMMessage>,
189 host_to_wasm_tx:mpsc::UnboundedSender<WASMMessage>,
191 async_callbacks:Arc<RwLock<HashMap<u64, AsyncCallback>>>,
193 next_callback_token:Arc<std::sync::atomic::AtomicU64>,
195}
196
197impl HostBridgeImpl {
198 pub fn new() -> Self {
200 let (_wasm_to_host_tx, wasm_to_host_rx) = mpsc::unbounded_channel();
201 let (host_to_wasm_tx, host_to_wasm_rx) = mpsc::unbounded_channel();
202
203 drop(host_to_wasm_rx);
206
207 Self {
208 host_functions:Arc::new(RwLock::new(HashMap::new())),
209 wasm_to_host_rx,
210 host_to_wasm_tx,
211 async_callbacks:Arc::new(RwLock::new(HashMap::new())),
212 next_callback_token:Arc::new(std::sync::atomic::AtomicU64::new(0)),
213 }
214 }
215
216 #[instrument(skip(self, callback))]
218 pub async fn register_host_function(
219 &self,
220 name:&str,
221 signature:FunctionSignature,
222 callback:HostFunctionCallback,
223 ) -> BridgeResult<()> {
224 debug!("Registering host function: {}", name);
225
226 let mut functions = self.host_functions.write().await;
227
228 if functions.contains_key(name) {
229 warn!("Host function already registered: {}", name);
230 }
231
232 functions.insert(
233 name.to_string(),
234 HostFunction { name:name.to_string(), signature, callback:Some(callback), async_callback:None },
235 );
236
237 debug!("Host function registered successfully: {}", name);
238 Ok(())
239 }
240
241 #[instrument(skip(self, callback))]
243 pub async fn register_async_host_function(
244 &self,
245 name:&str,
246 signature:FunctionSignature,
247 callback:AsyncHostFunctionCallback,
248 ) -> BridgeResult<()> {
249 debug!("Registering async host function: {}", name);
250
251 let mut functions = self.host_functions.write().await;
252
253 functions.insert(
254 name.to_string(),
255 HostFunction { name:name.to_string(), signature, callback:None, async_callback:Some(callback) },
256 );
257
258 debug!("Async host function registered successfully: {}", name);
259 Ok(())
260 }
261
262 #[instrument(skip(self, args))]
264 pub async fn call_host_function(&self, function_name:&str, args:Vec<Bytes>) -> BridgeResult<Bytes> {
265 debug!("Calling host function: {}", function_name);
266
267 let functions = self.host_functions.read().await;
268 let func = functions
269 .get(function_name)
270 .ok_or_else(|| BridgeError::FunctionNotFound(function_name.to_string()))?;
271
272 if let Some(callback) = func.callback {
273 let result =
275 callback(args).map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
276 debug!("Host function call completed: {}", function_name);
277 Ok(result)
278 } else if let Some(async_callback) = func.async_callback {
279 let future = async_callback(args);
281 let result = future
282 .await
283 .map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
284 debug!("Async host function call completed: {}", function_name);
285 Ok(result)
286 } else {
287 Err(BridgeError::FunctionNotFound(format!(
288 "No callback for function: {}",
289 function_name
290 )))
291 }
292 }
293
294 #[instrument(skip(self, message))]
296 pub async fn send_to_wasm(&self, message:WASMMessage) -> BridgeResult<()> {
297 let function_name = message.function.clone();
298 self.host_to_wasm_tx.send(message).map_err(|_| BridgeError::BridgeClosed)?;
299 debug!("Message sent to WASM: {}", function_name);
300 Ok(())
301 }
302
303 pub async fn receive_from_wasm(&mut self) -> Option<WASMMessage> { self.wasm_to_host_rx.recv().await }
305
306 #[instrument(skip(self))]
308 pub async fn create_async_callback(&self, message_id:String) -> (AsyncCallback, u64) {
309 let token = self.next_callback_token.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
310 let (tx, _rx) = oneshot::channel();
311
312 let callback = AsyncCallback {
314 sender:Arc::new(tokio::sync::Mutex::new(Some(tx))),
315 message_id:message_id.clone(),
316 };
317
318 self.async_callbacks.write().await.insert(token, callback.clone());
319
320 (callback, token)
321 }
322
323 #[instrument(skip(self))]
325 pub async fn get_callback(&self, token:u64) -> Option<AsyncCallback> {
326 self.async_callbacks.write().await.remove(&token)
327 }
328
329 pub async fn get_host_functions(&self) -> Vec<String> { self.host_functions.read().await.keys().cloned().collect() }
331
332 #[instrument(skip(self))]
334 pub async fn unregister_host_function(&self, name:&str) -> bool {
335 let mut functions = self.host_functions.write().await;
336 let removed = functions.remove(name).is_some();
337 if removed {
338 debug!("Host function unregistered: {}", name);
339 }
340 removed
341 }
342
343 pub async fn clear(&self) {
345 debug!("Clearing all registered host functions");
346 self.host_functions.write().await.clear();
347 self.async_callbacks.write().await.clear();
348 }
349}
350
351impl Default for HostBridgeImpl {
352 fn default() -> Self { Self::new() }
353}
354
355pub fn serialize_to_bytes<T:Serialize>(data:&T) -> Result<Bytes> {
357 serde_json::to_vec(data)
358 .map(Bytes::from)
359 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
360}
361
362pub fn deserialize_from_bytes<T:DeserializeOwned>(bytes:&Bytes) -> Result<T> {
364 serde_json::from_slice(bytes).map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))
365}
366
367pub fn marshal_args(args:Vec<Bytes>) -> Result<Vec<wasmtime::Val>> {
369 args.iter()
370 .map(|bytes| {
371 let value:serde_json::Value = serde_json::from_slice(bytes)?;
372 match value {
373 serde_json::Value::Number(n) => {
374 if let Some(i) = n.as_i64() {
375 Ok(wasmtime::Val::I32(i as i32))
376 } else if let Some(f) = n.as_f64() {
377 Ok(wasmtime::Val::F64(f.to_bits()))
378 } else {
379 Err(anyhow::anyhow!("Invalid number value"))
380 }
381 },
382 _ => Err(anyhow::anyhow!("Unsupported argument type")),
383 }
384 })
385 .collect()
386}
387
388pub fn unmarshal_return(val:wasmtime::Val) -> Result<Bytes> {
390 match val {
391 wasmtime::Val::I32(i) => {
392 let json = serde_json::to_string(&i)?;
393 Ok(Bytes::from(json))
394 },
395 wasmtime::Val::I64(i) => {
396 let json = serde_json::to_string(&i)?;
397 Ok(Bytes::from(json))
398 },
399 wasmtime::Val::F32(f) => {
400 let json = serde_json::to_string(&f)?;
401 Ok(Bytes::from(json))
402 },
403 wasmtime::Val::F64(f) => {
404 let json = serde_json::to_string(&f)?;
405 Ok(Bytes::from(json))
406 },
407 _ => Err(anyhow::anyhow!("Unsupported return type")),
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Signature used by tests that register a one-argument `echo` function.
    fn echo_signature() -> FunctionSignature {
        FunctionSignature {
            name:"echo".to_string(),
            param_types:vec![ParamType::I32],
            return_type:Some(ReturnType::I32),
            is_async:false,
        }
    }

    #[test]
    fn test_function_signature_creation() {
        let signature = FunctionSignature {
            name:"test_func".to_string(),
            param_types:vec![ParamType::I32, ParamType::Ptr],
            return_type:Some(ReturnType::I32),
            is_async:false,
        };

        assert_eq!(signature.name, "test_func");
        assert_eq!(signature.param_types.len(), 2);
    }

    #[tokio::test]
    async fn test_host_bridge_creation() {
        let bridge = HostBridgeImpl::new();
        assert_eq!(bridge.get_host_functions().await.len(), 0);
    }

    #[tokio::test]
    async fn test_register_host_function() {
        let bridge = HostBridgeImpl::new();

        let result = bridge
            .register_host_function("echo", echo_signature(), |args| Ok(args[0].clone()))
            .await;

        assert!(result.is_ok());
        assert_eq!(bridge.get_host_functions().await.len(), 1);
    }

    #[tokio::test]
    async fn test_call_host_function_echo() {
        let bridge = HostBridgeImpl::new();
        bridge
            .register_host_function("echo", echo_signature(), |args| Ok(args[0].clone()))
            .await
            .unwrap();

        let out = bridge
            .call_host_function("echo", vec![Bytes::from_static(b"hi")])
            .await
            .unwrap();
        assert_eq!(out, Bytes::from_static(b"hi"));
    }

    #[tokio::test]
    async fn test_call_unknown_host_function() {
        let bridge = HostBridgeImpl::new();
        let err = bridge.call_host_function("missing", Vec::new()).await.unwrap_err();
        assert!(matches!(err, BridgeError::FunctionNotFound(ref name) if name == "missing"));
    }

    #[tokio::test]
    async fn test_unregister_host_function() {
        let bridge = HostBridgeImpl::new();
        bridge
            .register_host_function("echo", echo_signature(), |args| Ok(args[0].clone()))
            .await
            .unwrap();

        assert!(bridge.unregister_host_function("echo").await);
        assert!(!bridge.unregister_host_function("echo").await);
        assert!(bridge.get_host_functions().await.is_empty());
    }

    #[test]
    fn test_serialize_deserialize() {
        let data = vec![1, 2, 3, 4, 5];
        let bytes = serialize_to_bytes(&data).unwrap();
        let recovered:Vec<i32> = deserialize_from_bytes(&bytes).unwrap();
        assert_eq!(data, recovered);
    }

    #[test]
    fn test_marshal_unmarshal() {
        // 2.5 is exactly representable, so its bit pattern survives the JSON round trip.
        // It also avoids clippy's deny-by-default `approx_constant` lint that 3.14 triggers.
        let args = vec![serialize_to_bytes(&42i32).unwrap(), serialize_to_bytes(&2.5f64).unwrap()];

        let marshaled = marshal_args(args).unwrap();
        assert_eq!(marshaled.len(), 2);
        assert!(matches!(marshaled[0], wasmtime::Val::I32(42)));
        assert!(matches!(marshaled[1], wasmtime::Val::F64(bits) if bits == 2.5f64.to_bits()));

        let ret = unmarshal_return(wasmtime::Val::I32(7)).unwrap();
        assert_eq!(&ret[..], b"7");
    }
}