Files
RhythmicWallpaper/IpcLibrary/Services/MessageRouter.cs

109 lines
4.1 KiB
C#

using IpcLibrary.Core;
using IpcLibrary.Core.Exceptions;
using System;
using System.Collections.Concurrent;
namespace IpcLibrary.Services {
    /// <summary>
    /// Message router: tracks outgoing requests until their replies arrive and dispatches
    /// incoming messages (requests, responses, errors, notifications) to the matching handler.
    /// </summary>
    /// <remarks>
    /// Requests and replies are correlated by the message <c>Id</c>: replies built by this router
    /// (response or error) reuse the id of the request they answer, and an incoming reply completes
    /// the pending request registered under that same id.
    /// Dispose the router to stop its background sweep timer and cancel any still-pending requests.
    /// </remarks>
    public class MessageRouter : IDisposable {
        /// <summary>How often the background sweep looks for stale pending requests.</summary>
        private static readonly TimeSpan SweepInterval = TimeSpan.FromSeconds(10);

        /// <summary>
        /// Hard cap on how long a request may stay pending before the background sweep fails it,
        /// independent of the per-call timeout given to <see cref="SendRequestAsync"/>.
        /// </summary>
        private static readonly TimeSpan MaxPendingAge = TimeSpan.FromMinutes(5);

        // Pending requests keyed by request id, together with the UTC time they were registered
        // (the timestamp is what lets the sweep decide whether an entry is actually stale).
        private readonly ConcurrentDictionary<string, (TaskCompletionSource<IPCMessage> Completion, DateTime CreatedUtc)> _pendingRequests;
        private readonly MethodCallHandler _methodCallHandler;
        private readonly Timer _timeoutTimer;

        /// <summary>
        /// Creates a router that executes incoming method-call requests with <paramref name="methodCallHandler"/>.
        /// </summary>
        /// <param name="methodCallHandler">Handler used to execute incoming method calls.</param>
        /// <exception cref="ArgumentNullException"><paramref name="methodCallHandler"/> is null.</exception>
        public MessageRouter(MethodCallHandler methodCallHandler) {
            _pendingRequests = new ConcurrentDictionary<string, (TaskCompletionSource<IPCMessage> Completion, DateTime CreatedUtc)>();
            _methodCallHandler = methodCallHandler ?? throw new ArgumentNullException(nameof(methodCallHandler));
            _timeoutTimer = new Timer(CheckTimeouts, null, SweepInterval, SweepInterval);
        }

        /// <summary>
        /// Sends <paramref name="request"/> via <paramref name="sendAction"/> and waits for the reply
        /// whose id matches <c>request.Id</c>.
        /// </summary>
        /// <param name="request">The request to send; its <c>Id</c> is used to correlate the reply.</param>
        /// <param name="sendAction">Transport callback; returns <c>false</c> when the message could not be sent.</param>
        /// <param name="timeout">How long to wait (measured from before sending) before failing with a timeout.</param>
        /// <returns>
        /// The reply message. This may be a <c>MessageType.Error</c> message, so callers should check its type/error.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="request"/> or <paramref name="sendAction"/> is null.</exception>
        /// <exception cref="InvalidOperationException">Another request with the same id is already pending.</exception>
        /// <exception cref="ConnectionException"><paramref name="sendAction"/> reported that sending failed.</exception>
        /// <exception cref="Core.Exceptions.TimeoutException">
        /// No reply arrived within <paramref name="timeout"/>, or within the router-wide pending-age cap.
        /// </exception>
        public async Task<IPCMessage> SendRequestAsync(IPCMessage request, Func<IPCMessage, Task<bool>> sendAction, TimeSpan timeout) {
            if (request is null) {
                throw new ArgumentNullException(nameof(request));
            }
            if (sendAction is null) {
                throw new ArgumentNullException(nameof(sendAction));
            }

            // Run continuations asynchronously so completing this TCS from the dispatch thread
            // (HandleResponse / timer callback) never executes the caller's code inline there.
            var completion = new TaskCompletionSource<IPCMessage>(TaskCreationOptions.RunContinuationsAsynchronously);

            // Refuse duplicates instead of silently orphaning the earlier request with the same id.
            if (!_pendingRequests.TryAdd(request.Id, (completion, DateTime.UtcNow))) {
                throw new InvalidOperationException($"A request with id '{request.Id}' is already pending.");
            }

            try {
                // Created inside the try so an invalid timeout still removes the pending entry.
                using var timeoutCts = new CancellationTokenSource(timeout);
                using var timeoutRegistration = timeoutCts.Token.Register(
                    () => completion.TrySetException(new Core.Exceptions.TimeoutException("请求超时")));

                if (!await sendAction(request).ConfigureAwait(false)) {
                    throw new ConnectionException("发送请求失败");
                }

                return await completion.Task.ConfigureAwait(false);
            } finally {
                _pendingRequests.TryRemove(request.Id, out _);
            }
        }

        /// <summary>
        /// Dispatches an incoming message by its type.
        /// </summary>
        /// <param name="message">The received message.</param>
        /// <returns>
        /// For requests, the reply to send back (or <c>null</c> if the request was not handled);
        /// <c>null</c> for every other message type.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="message"/> is null.</exception>
        public async Task<IPCMessage> HandleMessageAsync(IPCMessage message) {
            if (message is null) {
                throw new ArgumentNullException(nameof(message));
            }

            switch (message.Type) {
                case MessageType.Request:
                    return await HandleRequestAsync(message).ConfigureAwait(false);
                case MessageType.Response:
                case MessageType.Error:
                    // Error replies answer a pending request too; without this they were dropped
                    // and the caller only ever observed a timeout.
                    HandleResponse(message);
                    return null;
                case MessageType.Notification:
                    await HandleNotificationAsync(message).ConfigureAwait(false);
                    return null;
                default:
                    return null;
            }
        }

        /// <summary>
        /// Executes a method-call request and builds the reply addressed back to the sender.
        /// </summary>
        /// <remarks>
        /// The reply reuses <c>request.Id</c> so the sender's router can match it to its pending request.
        /// Returns <c>null</c> when the first parameter is not a <see cref="MethodCallRequest"/>.
        /// Any exception from the handler is turned into a <c>MessageType.Error</c> reply.
        /// </remarks>
        private async Task<IPCMessage> HandleRequestAsync(IPCMessage request) {
            try {
                if (request.Parameters?.Length > 0 && request.Parameters[0] is MethodCallRequest methodCall) {
                    var response = await _methodCallHandler.HandleMethodCallAsync(methodCall).ConfigureAwait(false);
                    return new IPCMessage {
                        Id = request.Id,
                        Type = MessageType.Response,
                        Result = response,
                        SourceProcessId = request.TargetProcessId,
                        TargetProcessId = request.SourceProcessId
                    };
                }
            } catch (Exception ex) {
                // RPC boundary: report the failure to the caller instead of letting it time out.
                return new IPCMessage {
                    Id = request.Id,
                    Type = MessageType.Error,
                    Error = ex.Message,
                    SourceProcessId = request.TargetProcessId,
                    TargetProcessId = request.SourceProcessId
                };
            }
            return null;
        }

        /// <summary>
        /// Completes the pending request whose id matches <paramref name="response"/>; unknown ids are ignored.
        /// </summary>
        private void HandleResponse(IPCMessage response) {
            // TryRemove + TrySetResult: a late or duplicate reply (e.g. after the timeout already fired)
            // must not throw InvalidOperationException the way SetResult would.
            if (response.Id != null && _pendingRequests.TryRemove(response.Id, out var pending)) {
                pending.Completion.TrySetResult(response);
            }
        }

        /// <summary>
        /// Handles a notification message. Currently a no-op placeholder.
        /// </summary>
        private Task HandleNotificationAsync(IPCMessage notification) {
            return Task.CompletedTask;
        }

        /// <summary>
        /// Timer callback: fails requests that have been pending longer than <see cref="MaxPendingAge"/>.
        /// </summary>
        private void CheckTimeouts(object state) {
            var cutoff = DateTime.UtcNow - MaxPendingAge;

            // ConcurrentDictionary supports removal while enumerating.
            foreach (var entry in _pendingRequests) {
                if (entry.Value.CreatedUtc <= cutoff && _pendingRequests.TryRemove(entry.Key, out var expired)) {
                    expired.Completion.TrySetException(new Core.Exceptions.TimeoutException("请求超时"));
                }
            }
        }

        /// <summary>
        /// Stops the background sweep timer and cancels every request that is still pending.
        /// </summary>
        public void Dispose() {
            _timeoutTimer.Dispose();

            foreach (var entry in _pendingRequests) {
                if (_pendingRequests.TryRemove(entry.Key, out var pending)) {
                    pending.Completion.TrySetCanceled();
                }
            }
        }
    }
}