Skip to main content

Grove/Transport/
WASMTransport.rs

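//! WASM transport implementation.
//!
//! Provides direct, in-process communication with WebAssembly modules and
//! handles calls to and from WebAssembly instances. Requests sent through the
//! transport are UTF-8 strings of the form `module_id:function_name:base64_args`.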
5
6use std::{collections::HashMap, path::PathBuf, sync::Arc};
7
8use async_trait::async_trait;
9use base64::Engine;
10use bytes::Bytes;
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13use tracing::{debug, info, instrument};
14
15use crate::{
16	Transport::{
17		Strategy::{TransportStats, TransportStrategy, TransportType},
18		TransportConfig,
19	},
20	WASM::{
21		HostBridge::HostBridgeImpl,
22		MemoryManager::{MemoryLimits, MemoryManagerImpl},
23		Runtime::{WASMConfig, WASMRuntime},
24		WASMStats,
25	},
26};
27
28/// WASM transport for direct module communication
29#[derive(Clone, Debug)]
30pub struct WASMTransportImpl {
31	/// WASM runtime
32	runtime:Arc<WASMRuntime>,
33	/// Memory manager
34	memory_manager:Arc<RwLock<MemoryManagerImpl>>,
35	/// Host bridge for communication
36	bridge:Arc<HostBridgeImpl>,
37	/// Loaded modules
38	modules:Arc<RwLock<HashMap<String, WASMModuleInfo>>>,
39	/// Transport configuration
40	#[allow(dead_code)]
41	config:TransportConfig,
42	/// Connection state
43	connected:Arc<RwLock<bool>>,
44	/// Transport statistics
45	stats:Arc<RwLock<TransportStats>>,
46}
47
48/// Information about a loaded WASM module
49#[derive(Debug, Clone)]
50pub struct WASMModuleInfo {
51	/// Module ID
52	pub id:String,
53	/// Module name (if available)
54	pub name:Option<String>,
55	/// Path to module file
56	pub path:Option<PathBuf>,
57	/// Module loaded timestamp
58	pub loaded_at:u64,
59	/// Function statistics
60	pub function_stats:HashMap<String, FunctionCallStats>,
61}
62
63/// Statistics for function calls
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct FunctionCallStats {
66	/// Number of calls
67	pub call_count:u64,
68	/// Total execution time in microseconds
69	pub total_time_us:u64,
70	/// Last call timestamp
71	pub last_call_at:Option<u64>,
72	/// Number of errors
73	pub error_count:u64,
74}
75
76impl FunctionCallStats {
77	/// Record a successful function call
78	pub fn record_call(&mut self, time_us:u64) {
79		self.call_count += 1;
80		self.total_time_us += time_us;
81		self.last_call_at = Some(
82			std::time::SystemTime::now()
83				.duration_since(std::time::UNIX_EPOCH)
84				.map(|d| d.as_secs())
85				.unwrap_or(0),
86		);
87	}
88
89	/// Record a failed function call
90	pub fn record_error(&mut self) { self.error_count += 1; }
91}
92
93impl Default for FunctionCallStats {
94	fn default() -> Self { Self { call_count:0, total_time_us:0, last_call_at:None, error_count:0 } }
95}
96
97impl WASMTransportImpl {
98	/// Create a new WASM transport with default configuration
99	pub fn new(enable_wasi:bool, memory_limit_mb:u64, max_execution_time_ms:u64) -> anyhow::Result<Self> {
100		let config = WASMConfig::new(memory_limit_mb, max_execution_time_ms, enable_wasi);
101
102		// Create runtime - this would normally be async, but for now we do it
103		// synchronously In production, this would need to be properly awaited
104		let runtime_result = tokio::runtime::Runtime::new()
105			.map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
106			.block_on(WASMRuntime::new(config.clone()))
107			.map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
108		let runtime = Arc::new(runtime_result);
109
110		let memory_limits = MemoryLimits::new(memory_limit_mb, (memory_limit_mb as f64 * 0.75) as u64, 100);
111		let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
112		let bridge = Arc::new(HostBridgeImpl::new());
113
114		Ok(Self {
115			runtime,
116			memory_manager,
117			bridge,
118			modules:Arc::new(RwLock::new(HashMap::new())),
119			config:TransportConfig::default(),
120			connected:Arc::new(RwLock::new(true)), // WASM transport is always "connected" locally
121			stats:Arc::new(RwLock::new(TransportStats::default())),
122		})
123	}
124
125	/// Create a new WASM transport with custom configuration
126	pub fn with_config(wasm_config:WASMConfig, transport_config:TransportConfig) -> anyhow::Result<Self> {
127		let runtime_result = tokio::runtime::Runtime::new()
128			.map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
129			.block_on(WASMRuntime::new(wasm_config.clone()))
130			.map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
131		let runtime = Arc::new(runtime_result);
132
133		let memory_limits = MemoryLimits::new(
134			wasm_config.memory_limit_mb,
135			(wasm_config.memory_limit_mb as f64 * 0.75) as u64,
136			100,
137		);
138		let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
139		let bridge = Arc::new(HostBridgeImpl::new());
140
141		Ok(Self {
142			runtime,
143			memory_manager,
144			bridge,
145			modules:Arc::new(RwLock::new(HashMap::new())),
146			config:transport_config,
147			connected:Arc::new(RwLock::new(true)),
148			stats:Arc::new(RwLock::new(TransportStats::default())),
149		})
150	}
151
152	/// Get a reference to the WASM runtime
153	pub fn runtime(&self) -> &Arc<WASMRuntime> { &self.runtime }
154
155	/// Get a reference to the memory manager
156	pub fn memory_manager(&self) -> &Arc<RwLock<MemoryManagerImpl>> { &self.memory_manager }
157
158	/// Get a reference to the host bridge
159	pub fn bridge(&self) -> &Arc<HostBridgeImpl> { &self.bridge }
160
161	/// Get all loaded modules
162	pub async fn get_modules(&self) -> HashMap<String, WASMModuleInfo> { self.modules.read().await.clone() }
163
164	/// Get WASM runtime statistics
165	pub async fn get_wasm_stats(&self) -> WASMStats {
166		let memory_manager = self.memory_manager.read().await;
167		let managers = self.modules.read().await;
168
169		WASMStats {
170			modules_loaded:managers.len(),
171			active_instances:managers.len(), // In real implementation, track instances
172			total_memory_mb:memory_manager.current_usage_mb() as u64,
173			total_execution_time_ms:0, // Track from actual calls
174			function_calls:self.stats.read().await.messages_sent,
175		}
176	}
177
178	/// Call a function in a WASM module
179	#[instrument(skip(self, module_id, function_name, args))]
180	pub async fn call_wasm_function(
181		&self,
182		module_id:&str,
183		function_name:&str,
184		args:Vec<Bytes>,
185	) -> anyhow::Result<Bytes> {
186		let start = std::time::Instant::now();
187
188		debug!(
189			"Calling WASM function: {}::{} with {} arguments",
190			module_id,
191			function_name,
192			args.len()
193		);
194
195		let modules = self.modules.read().await;
196		let _module = modules
197			.get(module_id)
198			.ok_or_else(|| anyhow::anyhow!("Module not found: {}", module_id))?;
199
200		// In a real implementation, this would call the actual WASM function
201		// For now, we return a mock response
202		let response = Bytes::new();
203
204		// Update statistics
205		let mut modules_mut = self.modules.write().await;
206		if let Some(module) = modules_mut.get_mut(module_id) {
207			let stats = module.function_stats.entry(function_name.to_string()).or_default();
208			stats.record_call(start.elapsed().as_micros() as u64);
209		}
210
211		drop(modules_mut);
212
213		// Update transport statistics
214		let mut stats = self.stats.write().await;
215		stats.record_sent(args.iter().map(|b| b.len() as u64).sum(), start.elapsed().as_micros() as u64);
216		stats.record_received(response.len() as u64);
217
218		Ok(response)
219	}
220}
221
222#[async_trait]
223impl TransportStrategy for WASMTransportImpl {
224	type Error = WASMTransportError;
225
226	#[instrument(skip(self))]
227	async fn connect(&self) -> Result<(), Self::Error> {
228		info!("WASM transport connecting");
229
230		// WASM transport is always "connected" locally
231		*self.connected.write().await = true;
232
233		info!("WASM transport connected");
234
235		Ok(())
236	}
237
238	#[instrument(skip(self, request))]
239	async fn send(&self, request:&[u8]) -> Result<Vec<u8>, Self::Error> {
240		let start = std::time::Instant::now();
241
242		if !self.is_connected() {
243			return Err(WASMTransportError::NotConnected);
244		}
245
246		debug!("Sending WASM transport request ({} bytes)", request.len());
247
248		// Parse request - it should contain module ID and function name
249		// For simplicity, we use a minimal format: module_id:function_name:base64_args
250		let request_str =
251			std::str::from_utf8(request).map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?;
252
253		let parts:Vec<&str> = request_str.splitn(3, ':').collect();
254		if parts.len() < 3 {
255			return Err(WASMTransportError::InvalidRequest("Invalid request format".to_string()));
256		}
257
258		let module_id = parts[0];
259		let function_name = parts[1];
260		let args_base64 = parts[2];
261
262		// Decode arguments from base64
263		use base64::engine::general_purpose::STANDARD;
264		let args = vec![Bytes::from(
265			STANDARD
266				.decode(args_base64)
267				.map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?,
268		)];
269
270		// Call the WASM function
271		let response = self
272			.call_wasm_function(module_id, function_name, args)
273			.await
274			.map_err(|e| WASMTransportError::FunctionCallFailed(e.to_string()))?;
275
276		// Convert response to Vec<u8>
277		let response_vec = response.to_vec();
278
279		let latency_us = start.elapsed().as_micros() as u64;
280
281		debug!("WASM transport request completed in {}µs", latency_us);
282
283		Ok(response_vec)
284	}
285
286	#[instrument(skip(self, data))]
287	async fn send_no_response(&self, data:&[u8]) -> Result<(), Self::Error> {
288		if !self.is_connected() {
289			return Err(WASMTransportError::NotConnected);
290		}
291
292		debug!("Sending WASM transport request without response ({} bytes)", data.len());
293
294		// For fire-and-forget calls, we still execute but ignore the response
295		self.send(data).await?;
296		Ok(())
297	}
298
299	#[instrument(skip(self))]
300	async fn close(&self) -> Result<(), Self::Error> {
301		info!("Closing WASM transport");
302
303		*self.connected.write().await = false;
304
305		info!("WASM transport closed");
306
307		Ok(())
308	}
309
310	fn is_connected(&self) -> bool { self.connected.blocking_read().to_owned() }
311
312	fn transport_type(&self) -> TransportType { TransportType::WASM }
313}
314
315/// WASM transport errors
316#[derive(Debug, thiserror::Error)]
317pub enum WASMTransportError {
318	/// Module not found error
319	#[error("Module not found: {0}")]
320	ModuleNotFound(String),
321
322	/// Function not found error
323	#[error("Function not found: {0}")]
324	FunctionNotFound(String),
325
326	/// Function call failed error
327	#[error("Function call failed: {0}")]
328	FunctionCallFailed(String),
329
330	/// Memory error
331	#[error("Memory error: {0}")]
332	MemoryError(String),
333
334	/// Runtime error
335	#[error("Runtime error: {0}")]
336	RuntimeError(String),
337
338	/// Invalid request error
339	#[error("Invalid request: {0}")]
340	InvalidRequest(String),
341
342	/// Not connected error
343	#[error("Not connected")]
344	NotConnected,
345
346	/// Compilation failed error
347	#[error("Compilation failed: {0}")]
348	CompilationFailed(String),
349
350	/// Timeout error
351	#[error("Timeout")]
352	Timeout,
353}
354
#[cfg(test)]
mod tests {
	use super::*;
	use crate::Transport::Strategy::TransportStrategy;

	/// Build a transport with the settings shared by these tests.
	fn make_transport() -> WASMTransportImpl {
		WASMTransportImpl::new(true, 512, 30000).expect("WASM transport should be created")
	}

	#[test]
	fn test_wasm_transport_creation() {
		let transport = make_transport();
		// A freshly created WASM transport starts out connected
		assert!(transport.is_connected());
	}

	#[test]
	fn test_function_call_stats() {
		let mut stats = FunctionCallStats::default();
		stats.record_call(100);
		assert_eq!(stats.call_count, 1);
		assert_eq!(stats.total_time_us, 100);
		assert!(stats.last_call_at.is_some());
	}

	#[test]
	fn test_function_call_stats_records_errors_separately() {
		let mut stats = FunctionCallStats::default();
		stats.record_error();
		assert_eq!(stats.error_count, 1);
		assert_eq!(stats.call_count, 0);
		assert!(stats.last_call_at.is_none());
	}

	#[tokio::test]
	async fn test_wasm_transport_not_connected_after_close() {
		let transport = make_transport();
		transport.close().await.expect("closing the WASM transport should not fail");
		assert!(!transport.is_connected());
	}

	#[tokio::test]
	async fn test_get_wasm_stats() {
		let transport = make_transport();
		let stats = transport.get_wasm_stats().await;
		assert_eq!(stats.modules_loaded, 0);
		assert_eq!(stats.active_instances, 0);
	}
}