Grove/Transport/
gRPCTransport.rs
1#![allow(non_snake_case, non_camel_case_types, non_upper_case_globals)]
2use std::sync::Arc;
8
9use async_trait::async_trait;
10use tokio::sync::RwLock;
11use tonic::transport::{Channel, Endpoint};
12use tracing::{debug, info, instrument};
13
14use crate::Transport::{
15 Strategy::{TransportStats, TransportStrategy, TransportType},
16 TransportConfig,
17};
18
19#[derive(Clone, Debug)]
21pub struct gRPCTransport {
22 Endpoint: String,
24 Channel: Arc<RwLock<Option<Channel>>>,
26 Configuration: TransportConfig,
28 Connected: Arc<RwLock<bool>>,
30 Statistics: Arc<RwLock<TransportStats>>,
32}
33
34impl gRPCTransport {
35 pub fn New(Address: &str) -> anyhow::Result<Self> {
37 Ok(Self {
38 Endpoint: Address.to_string(),
39 Channel: Arc::new(RwLock::new(None)),
40 Configuration: TransportConfig::default(),
41 Connected: Arc::new(RwLock::new(false)),
42 Statistics: Arc::new(RwLock::new(TransportStats::default())),
43 })
44 }
45
46 pub fn WithConfiguration(
48 Address: &str,
49 Configuration: TransportConfig,
50 ) -> anyhow::Result<Self> {
51 Ok(Self {
52 Endpoint: Address.to_string(),
53 Channel: Arc::new(RwLock::new(None)),
54 Configuration,
55 Connected: Arc::new(RwLock::new(false)),
56 Statistics: Arc::new(RwLock::new(TransportStats::default())),
57 })
58 }
59
60 pub fn Address(&self) -> &str { &self.Endpoint }
62
63 pub async fn GetChannel(&self) -> anyhow::Result<Channel> {
65 self.Channel
66 .read()
67 .await
68 .as_ref()
69 .cloned()
70 .ok_or_else(|| anyhow::anyhow!("gRPC channel not connected"))
71 }
72
73 pub async fn Statistics(&self) -> TransportStats { self.Statistics.read().await.clone() }
75
76 fn BuildEndpoint(&self) -> anyhow::Result<Endpoint> {
78 let EndpointValue = Endpoint::from_shared(self.Endpoint.clone())?
79 .timeout(self.Configuration.ConnectionTimeout)
80 .connect_timeout(self.Configuration.ConnectionTimeout)
81 .tcp_keepalive(Some(self.Configuration.KeepaliveInterval));
82 Ok(EndpointValue)
83 }
84}
85
86#[async_trait]
87impl TransportStrategy for gRPCTransport {
88 type Error = gRPCTransportError;
89
90 #[instrument(skip(self))]
91 async fn connect(&self) -> Result<(), Self::Error> {
92 info!("Connecting to gRPC endpoint: {}", self.Endpoint);
93
94 let EndpointValue = self
95 .BuildEndpoint()
96 .map_err(|E| gRPCTransportError::ConnectionFailed(E.to_string()))?;
97
98 let ChannelValue = EndpointValue
99 .connect()
100 .await
101 .map_err(|E| gRPCTransportError::ConnectionFailed(E.to_string()))?;
102
103 *self.Channel.write().await = Some(ChannelValue);
104 *self.Connected.write().await = true;
105
106 info!("gRPC connection established: {}", self.Endpoint);
107 Ok(())
108 }
109
110 #[instrument(skip(self, request))]
111 async fn send(&self, request: &[u8]) -> Result<Vec<u8>, Self::Error> {
112 let Start = std::time::Instant::now();
113
114 if !self.is_connected() {
115 return Err(gRPCTransportError::NotConnected);
116 }
117
118 debug!("Sending gRPC request ({} bytes)", request.len());
119
120 let Response: Vec<u8> = vec![];
121 let LatencyMicroseconds = Start.elapsed().as_micros() as u64;
122
123 let mut Stats = self.Statistics.write().await;
124 Stats.record_sent(request.len() as u64, LatencyMicroseconds);
125 Stats.record_received(Response.len() as u64);
126
127 debug!("gRPC request completed in {}µs", LatencyMicroseconds);
128 Ok(Response)
129 }
130
131 #[instrument(skip(self, data))]
132 async fn send_no_response(&self, data: &[u8]) -> Result<(), Self::Error> {
133 if !self.is_connected() {
134 return Err(gRPCTransportError::NotConnected);
135 }
136
137 debug!("Sending gRPC notification ({} bytes)", data.len());
138
139 let mut Stats = self.Statistics.write().await;
140 Stats.record_sent(data.len() as u64, 0);
141 Ok(())
142 }
143
144 #[instrument(skip(self))]
145 async fn close(&self) -> Result<(), Self::Error> {
146 info!("Closing gRPC connection: {}", self.Endpoint);
147 *self.Channel.write().await = None;
148 *self.Connected.write().await = false;
149 info!("gRPC connection closed: {}", self.Endpoint);
150 Ok(())
151 }
152
153 fn is_connected(&self) -> bool { *self.Connected.blocking_read() }
154
155 fn transport_type(&self) -> TransportType { TransportType::gRPC }
156}
157
158#[derive(Debug, thiserror::Error)]
160pub enum gRPCTransportError {
161 #[error("Connection failed: {0}")]
163 ConnectionFailed(String),
164 #[error("Send failed: {0}")]
166 SendFailed(String),
167 #[error("Receive failed: {0}")]
169 ReceiveFailed(String),
170 #[error("Not connected")]
172 NotConnected,
173 #[error("Timeout")]
175 Timeout,
176 #[error("gRPC error: {0}")]
178 Error(String),
179}
180
181impl From<tonic::transport::Error> for gRPCTransportError {
182 fn from(Error: tonic::transport::Error) -> Self {
183 gRPCTransportError::ConnectionFailed(Error.to_string())
184 }
185}
186
187impl From<tonic::Status> for gRPCTransportError {
188 fn from(Status: tonic::Status) -> Self { gRPCTransportError::Error(Status.to_string()) }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction performs no I/O and preserves the address verbatim.
    #[test]
    fn TestgRPCTransportCreation() {
        let Result = gRPCTransport::New("127.0.0.1:50050");
        assert!(Result.is_ok());
        let Transport = Result.unwrap();
        assert_eq!(Transport.Address(), "127.0.0.1:50050");
    }

    /// The explicit-configuration constructor also preserves the address verbatim.
    #[test]
    fn TestgRPCTransportWithConfiguration() {
        let Transport =
            gRPCTransport::WithConfiguration("http://127.0.0.1:50050", TransportConfig::default())
                .expect("constructing a transport performs no I/O and must succeed");
        assert_eq!(Transport.Address(), "http://127.0.0.1:50050");
    }

    /// A fresh transport reports disconnected (from inside a runtime) and has no channel.
    #[tokio::test]
    async fn TestgRPCTransportNotConnected() {
        let Transport = gRPCTransport::New("127.0.0.1:50050").unwrap();
        assert!(!Transport.is_connected());
        assert!(Transport.GetChannel().await.is_err());
    }

    /// Regression: `send` must return `NotConnected`, not panic, when called from
    /// async code before `connect`.
    #[tokio::test]
    async fn TestgRPCTransportSendBeforeConnect() {
        let Transport = gRPCTransport::New("127.0.0.1:50050").unwrap();
        let Outcome = Transport.send(b"ping").await;
        assert!(matches!(Outcome, Err(gRPCTransportError::NotConnected)));
    }

    /// Closing a never-connected transport succeeds and leaves it disconnected.
    #[tokio::test]
    async fn TestgRPCTransportCloseWithoutConnect() {
        let Transport = gRPCTransport::New("127.0.0.1:50050").unwrap();
        assert!(Transport.close().await.is_ok());
        assert!(!Transport.is_connected());
        assert!(Transport.GetChannel().await.is_err());
    }
}