1use std::future::Future;
6
7use std::pin::Pin;
8
9use futures::{future, FutureExt};
10use http_body_util::{BodyExt, Limited};
11use hyper::header;
12use jsonrpsee::{
13 core::BoxError,
14 server::{HttpBody, HttpRequest, HttpResponse},
15};
16use jsonrpsee_types::ErrorObject;
17use serde::{Deserialize, Serialize};
18use tower::Service;
19
20use super::cookie::Cookie;
21
22use base64::{engine::general_purpose::STANDARD, Engine as _};
23
/// HTTP middleware for the RPC server, applied to every request before it reaches `service`.
///
/// The impl blocks in this file use it to:
/// - check cookie-based HTTP `Authorization` credentials, when a cookie is configured,
/// - normalize the `content-type` header to `application/json`, and
/// - translate request bodies from older JSON-RPC dialects into JSON-RPC 2.0, and translate the
///   responses back into the dialect the client used.
#[derive(Clone, Debug)]
pub struct HttpRequestMiddleware<S> {
    /// The wrapped inner service that handles the (rewritten) request.
    service: S,
    /// The expected authentication cookie, or `None` to skip the credential check.
    cookie: Option<Cookie>,
    /// Maximum number of request body bytes buffered before reading the body fails.
    max_request_body_size: usize,
}
61
62impl<S> HttpRequestMiddleware<S> {
63 pub fn new(service: S, cookie: Option<Cookie>, max_request_body_size: usize) -> Self {
65 Self {
66 service,
67 cookie,
68 max_request_body_size,
69 }
70 }
71
72 pub fn check_credentials(&self, headers: &header::HeaderMap) -> bool {
74 self.cookie.as_ref().is_none_or(|internal_cookie| {
75 headers
76 .get(header::AUTHORIZATION)
77 .and_then(|auth_header| auth_header.to_str().ok())
78 .and_then(|auth_header| auth_header.split_whitespace().nth(1))
79 .and_then(|encoded| STANDARD.decode(encoded).ok())
80 .and_then(|decoded| String::from_utf8(decoded).ok())
81 .and_then(|request_cookie| request_cookie.split(':').nth(1).map(String::from))
82 .is_some_and(|passwd| internal_cookie.authenticate(passwd))
83 })
84 }
85
86 pub fn insert_or_replace_content_type_header(headers: &mut header::HeaderMap) {
108 if !headers.contains_key(header::CONTENT_TYPE)
109 || headers
110 .get(header::CONTENT_TYPE)
111 .filter(|value| {
112 value
113 .to_str()
114 .ok()
115 .unwrap_or_default()
116 .starts_with("text/plain")
117 })
118 .is_some()
119 {
120 headers.insert(
121 header::CONTENT_TYPE,
122 header::HeaderValue::from_static("application/json"),
123 );
124 }
125 }
126
127 async fn request_to_json_rpc_2(
129 request: HttpRequest<HttpBody>,
130 max_request_body_size: usize,
131 ) -> Result<(JsonRpcVersion, HttpRequest<HttpBody>), BoxError> {
132 let (parts, body) = request.into_parts();
133 let bytes = Limited::new(body, max_request_body_size)
134 .collect()
135 .await?
136 .to_bytes();
137 let (version, bytes) =
138 if let Ok(request) = serde_json::from_slice::<'_, JsonRpcRequest>(bytes.as_ref()) {
139 let version = request.version();
140 if matches!(version, JsonRpcVersion::Unknown) {
141 (version, bytes)
142 } else {
143 (
144 version,
145 serde_json::to_vec(&request.into_2()).expect("valid").into(),
146 )
147 }
148 } else {
149 (JsonRpcVersion::Unknown, bytes)
150 };
151 Ok((
152 version,
153 HttpRequest::from_parts(parts, HttpBody::from(bytes.as_ref().to_vec())),
154 ))
155 }
156 async fn response_from_json_rpc_2(
158 version: JsonRpcVersion,
159 response: HttpResponse<HttpBody>,
160 ) -> Result<HttpResponse<HttpBody>, BoxError> {
161 let (parts, body) = response.into_parts();
162 let bytes = body.collect().await?.to_bytes();
163 let bytes =
164 if let Ok(response) = serde_json::from_slice::<'_, JsonRpcResponse>(bytes.as_ref()) {
165 serde_json::to_vec(&response.into_version(version))
166 .expect("valid")
167 .into()
168 } else {
169 bytes
170 };
171 Ok(HttpResponse::from_parts(
172 parts,
173 HttpBody::from(bytes.as_ref().to_vec()),
174 ))
175 }
176}
177
/// A [`tower::Layer`] that wraps services in an [`HttpRequestMiddleware`].
///
/// Each wrapped service receives its own clone of this layer's settings.
#[derive(Clone)]
pub struct HttpRequestMiddlewareLayer {
    /// The expected authentication cookie, or `None` to skip the credential check.
    cookie: Option<Cookie>,
    /// Maximum number of request body bytes passed on to each middleware instance.
    max_request_body_size: usize,
}
184
185impl HttpRequestMiddlewareLayer {
186 pub fn new(cookie: Option<Cookie>, max_request_body_size: usize) -> Self {
188 Self {
189 cookie,
190 max_request_body_size,
191 }
192 }
193}
194
195impl<S> tower::Layer<S> for HttpRequestMiddlewareLayer {
196 type Service = HttpRequestMiddleware<S>;
197
198 fn layer(&self, service: S) -> Self::Service {
199 HttpRequestMiddleware::new(service, self.cookie.clone(), self.max_request_body_size)
200 }
201}
202
203impl<S> Service<HttpRequest<HttpBody>> for HttpRequestMiddleware<S>
204where
205 S: Service<HttpRequest, Response = HttpResponse> + std::clone::Clone + Send + 'static,
206 S::Error: Into<BoxError> + 'static,
207 S::Future: Send + 'static,
208{
209 type Response = S::Response;
210 type Error = BoxError;
211 type Future =
212 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
213
214 fn poll_ready(
215 &mut self,
216 cx: &mut std::task::Context<'_>,
217 ) -> std::task::Poll<Result<(), Self::Error>> {
218 self.service.poll_ready(cx).map_err(Into::into)
219 }
220
221 fn call(&mut self, mut request: HttpRequest<HttpBody>) -> Self::Future {
222 if !self.check_credentials(request.headers_mut()) {
224 let error = ErrorObject::borrowed(401, "unauthenticated method", None);
225 return future::err(BoxError::from(error)).boxed();
227 }
228
229 Self::insert_or_replace_content_type_header(request.headers_mut());
231
232 let mut service = self.service.clone();
233 let max_request_body_size = self.max_request_body_size;
234
235 async move {
236 let (version, request) =
237 Self::request_to_json_rpc_2(request, max_request_body_size).await?;
238 let response = service.call(request).await.map_err(Into::into)?;
239 Self::response_from_json_rpc_2(version, response).await
240 }
241 .boxed()
242 }
243}
244
/// The JSON-RPC dialect a request was sent in, as detected by `JsonRpcRequest::version`.
///
/// The detected dialect is used to translate the response back into the client's format.
#[derive(Clone, Copy, Debug)]
enum JsonRpcVersion {
    /// No `jsonrpc` field, but both `params` and `id` are present.
    Bitcoind,
    /// `"jsonrpc": "1.0"`, with both `params` and `id` present.
    Lightwalletd,
    /// `"jsonrpc": "2.0"`, with `id` absent, `null`, a string, or a number.
    TwoPointZero,
    /// Anything else. Such requests and their responses are passed through without
    /// dialect translation.
    Unknown,
}
257
/// A single JSON-RPC request, permissive enough to accept several JSON-RPC dialects.
///
/// Every field except `method` is optional so that requests without `jsonrpc`, `params` or `id`
/// still deserialize. Absent optional fields are omitted again when the request is serialized.
/// Fields are serialized in declaration order.
#[derive(Debug, Deserialize, Serialize)]
struct JsonRpcRequest {
    /// The `jsonrpc` version string, if the client sent one.
    #[serde(skip_serializing_if = "Option::is_none")]
    jsonrpc: Option<String>,
    /// The name of the RPC method to call.
    method: String,
    /// The method parameters, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<serde_json::Value>,
    /// The request `id`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<serde_json::Value>,
}
269
270impl JsonRpcRequest {
271 fn version(&self) -> JsonRpcVersion {
272 match (self.jsonrpc.as_deref(), &self.params, &self.id) {
273 (
274 Some("2.0"),
275 _,
276 None
277 | Some(
278 serde_json::Value::Null
279 | serde_json::Value::String(_)
280 | serde_json::Value::Number(_),
281 ),
282 ) => JsonRpcVersion::TwoPointZero,
283 (Some("1.0"), Some(_), Some(_)) => JsonRpcVersion::Lightwalletd,
284 (None, Some(_), Some(_)) => JsonRpcVersion::Bitcoind,
285 _ => JsonRpcVersion::Unknown,
286 }
287 }
288
289 fn into_2(mut self) -> Self {
290 self.jsonrpc = Some("2.0".into());
291 self
292 }
293}
/// A single JSON-RPC response, as produced by the inner service.
///
/// Optional fields are omitted when serialized if they are `None`. Fields are serialized in
/// declaration order.
#[derive(Debug, Deserialize, Serialize)]
struct JsonRpcResponse {
    /// The `jsonrpc` version string, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    jsonrpc: Option<String>,
    /// The response `id`.
    id: serde_json::Value,
    /// The successful result, kept as raw JSON text so it is re-emitted without being parsed
    /// into a `serde_json::Value`.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Box<serde_json::value::RawValue>>,
    /// The error object, if the call failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<serde_json::Value>,
}
305
306impl JsonRpcResponse {
307 fn into_version(mut self, version: JsonRpcVersion) -> Self {
308 match version {
309 JsonRpcVersion::Bitcoind => {
310 self.jsonrpc = None;
311 self.result = self
312 .result
313 .or_else(|| serde_json::value::to_raw_value(&()).ok());
314 self.error = self.error.or(Some(serde_json::Value::Null));
315 }
316 JsonRpcVersion::Lightwalletd => {
317 self.jsonrpc = Some("1.0".into());
318 self.result = self
319 .result
320 .or_else(|| serde_json::value::to_raw_value(&()).ok());
321 self.error = self.error.or(Some(serde_json::Value::Null));
322 }
323 JsonRpcVersion::TwoPointZero => {
324 assert_eq!(self.jsonrpc.as_deref(), Some("2.0"));
328 if self.error.is_none() {
329 self.result = self
330 .result
331 .or_else(|| serde_json::value::to_raw_value(&()).ok());
332 } else {
333 assert!(self.result.is_none());
334 }
335 }
336 JsonRpcVersion::Unknown => (),
337 }
338 self
339 }
340}