refactor(messaging/websocket): 优化 WebSocket 相关配置及命名

This commit is contained in:
2024-06-23 11:30:03 +08:00
parent a208fa59b2
commit 6c10e80d71
8 changed files with 54 additions and 97 deletions

View File

@@ -31,7 +31,7 @@ import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer; import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.HandshakeInterceptor;
import top.continew.starter.core.constant.PropertiesConstants; import top.continew.starter.core.constant.PropertiesConstants;
import top.continew.starter.messaging.websocket.core.CurrentUserProvider; import top.continew.starter.messaging.websocket.core.WebSocketClientService;
import top.continew.starter.messaging.websocket.core.WebSocketInterceptor; import top.continew.starter.messaging.websocket.core.WebSocketInterceptor;
import top.continew.starter.messaging.websocket.dao.WebSocketSessionDao; import top.continew.starter.messaging.websocket.dao.WebSocketSessionDao;
import top.continew.starter.messaging.websocket.dao.WebSocketSessionDaoDefaultImpl; import top.continew.starter.messaging.websocket.dao.WebSocketSessionDaoDefaultImpl;
@@ -73,7 +73,7 @@ public class WebSocketAutoConfiguration {
@Bean @Bean
@ConditionalOnMissingBean @ConditionalOnMissingBean
public HandshakeInterceptor handshakeInterceptor() { public HandshakeInterceptor handshakeInterceptor() {
return new WebSocketInterceptor(properties, SpringUtil.getBean(CurrentUserProvider.class)); return new WebSocketInterceptor(properties, SpringUtil.getBean(WebSocketClientService.class));
} }
/** /**
@@ -86,12 +86,12 @@ public class WebSocketAutoConfiguration {
} }
/** /**
* 当前用户 Provider(如不提供,则报错) * WebSocket 客户端服务(如不提供,则报错)
*/ */
@Bean @Bean
@ConditionalOnMissingBean @ConditionalOnMissingBean
public CurrentUserProvider currentUserProvider() { public WebSocketClientService webSocketClientService() {
throw new NoSuchBeanDefinitionException(CurrentUserProvider.class); throw new NoSuchBeanDefinitionException(WebSocketClientService.class);
} }
@PostConstruct @PostConstruct

View File

@@ -52,9 +52,9 @@ public class WebSocketProperties {
private List<String> allowedOrigins = new ArrayList<>(ALL); private List<String> allowedOrigins = new ArrayList<>(ALL);
/** /**
* 当前登录用户 Key * 客户端 ID Key
*/ */
private String currentUserKey = "CURRENT_USER"; private String clientIdKey = "CLIENT_ID";
public boolean isEnabled() { public boolean isEnabled() {
return enabled; return enabled;
@@ -80,11 +80,11 @@ public class WebSocketProperties {
this.allowedOrigins = allowedOrigins; this.allowedOrigins = allowedOrigins;
} }
public String getCurrentUserKey() { public String getClientIdKey() {
return currentUserKey; return clientIdKey;
} }
public void setCurrentUserKey(String currentUserKey) { public void setClientIdKey(String clientIdKey) {
this.currentUserKey = currentUserKey; this.clientIdKey = clientIdKey;
} }
} }

View File

@@ -17,21 +17,20 @@
package top.continew.starter.messaging.websocket.core; package top.continew.starter.messaging.websocket.core;
import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest;
import top.continew.starter.messaging.websocket.model.CurrentUser;
/** /**
* 当前登录用户 Provider * WebSocket 客户端服务
* *
* @author Charles7c * @author Charles7c
* @since 2.1.0 * @since 2.1.0
*/ */
public interface CurrentUserProvider { public interface WebSocketClientService {
/** /**
* 获取当前登录用户 * 获取当前客户端 ID
* *
* @param request 请求对象 * @param request 请求对象
* @return 当前登录用户 * @return 当前客户端 ID
*/ */
CurrentUser getCurrentUser(ServletServerHttpRequest request); String getClientId(ServletServerHttpRequest request);
} }

View File

@@ -22,10 +22,12 @@ import org.slf4j.LoggerFactory;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler;
import top.continew.starter.messaging.websocket.autoconfigure.WebSocketProperties; import top.continew.starter.messaging.websocket.autoconfigure.WebSocketProperties;
import top.continew.starter.messaging.websocket.dao.WebSocketSessionDao; import top.continew.starter.messaging.websocket.dao.WebSocketSessionDao;
import java.io.IOException;
/** /**
* WebSocket 处理器 * WebSocket 处理器
* *
@@ -33,7 +35,7 @@ import top.continew.starter.messaging.websocket.dao.WebSocketSessionDao;
* @author Charles7c * @author Charles7c
* @since 2.1.0 * @since 2.1.0
*/ */
public class WebSocketHandler extends AbstractWebSocketHandler { public class WebSocketHandler extends TextWebSocketHandler {
private static final Logger log = LoggerFactory.getLogger(WebSocketHandler.class); private static final Logger log = LoggerFactory.getLogger(WebSocketHandler.class);
private final WebSocketProperties webSocketProperties; private final WebSocketProperties webSocketProperties;
@@ -46,26 +48,41 @@ public class WebSocketHandler extends AbstractWebSocketHandler {
@Override @Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
log.info("WebSocket receive message. sessionId: {}, message: {}.", session.getId(), message.getPayload()); String clientId = this.getClientId(session);
log.info("WebSocket receive message. clientId: {}, message: {}.", clientId, message.getPayload());
super.handleTextMessage(session, message); super.handleTextMessage(session, message);
} }
@Override @Override
public void afterConnectionEstablished(WebSocketSession session) { public void afterConnectionEstablished(WebSocketSession session) {
String sessionKey = Convert.toStr(session.getAttributes().get(webSocketProperties.getCurrentUserKey())); String clientId = this.getClientId(session);
webSocketSessionDao.add(sessionKey, session); webSocketSessionDao.add(clientId, session);
log.info("WebSocket connect successfully. sessionKey: {}.", sessionKey); log.info("WebSocket client connect successfully. clientId: {}.", clientId);
} }
@Override @Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) { public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
String sessionKey = Convert.toStr(session.getAttributes().get(webSocketProperties.getCurrentUserKey())); String clientId = this.getClientId(session);
webSocketSessionDao.delete(sessionKey); webSocketSessionDao.delete(clientId);
log.info("WebSocket connect closed. sessionKey: {}.", sessionKey); log.info("WebSocket client connect closed. clientId: {}.", clientId);
} }
@Override @Override
public void handleTransportError(WebSocketSession session, Throwable exception) { public void handleTransportError(WebSocketSession session, Throwable exception) throws IOException {
log.error("WebSocket transport error. sessionId: {}.", session.getId(), exception); String clientId = this.getClientId(session);
if (session.isOpen()) {
session.close();
}
webSocketSessionDao.delete(clientId);
}
/**
* 获取客户端 ID
*
* @param session 会话
* @return 客户端 ID
*/
private String getClientId(WebSocketSession session) {
return Convert.toStr(session.getAttributes().get(webSocketProperties.getClientIdKey()));
} }
} }

View File

@@ -22,7 +22,6 @@ import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import top.continew.starter.messaging.websocket.autoconfigure.WebSocketProperties; import top.continew.starter.messaging.websocket.autoconfigure.WebSocketProperties;
import top.continew.starter.messaging.websocket.model.CurrentUser;
import java.util.Map; import java.util.Map;
@@ -36,11 +35,12 @@ import java.util.Map;
public class WebSocketInterceptor extends HttpSessionHandshakeInterceptor { public class WebSocketInterceptor extends HttpSessionHandshakeInterceptor {
private final WebSocketProperties webSocketProperties; private final WebSocketProperties webSocketProperties;
private final CurrentUserProvider currentUserProvider; private final WebSocketClientService webSocketClientService;
public WebSocketInterceptor(WebSocketProperties webSocketProperties, CurrentUserProvider currentUserProvider) { public WebSocketInterceptor(WebSocketProperties webSocketProperties,
WebSocketClientService webSocketClientService) {
this.webSocketProperties = webSocketProperties; this.webSocketProperties = webSocketProperties;
this.currentUserProvider = currentUserProvider; this.webSocketClientService = webSocketClientService;
} }
@Override @Override
@@ -48,8 +48,8 @@ public class WebSocketInterceptor extends HttpSessionHandshakeInterceptor {
ServerHttpResponse response, ServerHttpResponse response,
WebSocketHandler wsHandler, WebSocketHandler wsHandler,
Map<String, Object> attributes) { Map<String, Object> attributes) {
CurrentUser currentUser = currentUserProvider.getCurrentUser((ServletServerHttpRequest)request); String clientId = webSocketClientService.getClientId((ServletServerHttpRequest)request);
attributes.put(webSocketProperties.getCurrentUserKey(), currentUser.getUserId()); attributes.put(webSocketProperties.getClientIdKey(), clientId);
return true; return true;
} }

View File

@@ -1,58 +0,0 @@
/*
* Copyright (c) 2022-present Charles7c Authors. All Rights Reserved.
* <p>
* Licensed under the GNU LESSER GENERAL PUBLIC LICENSE 3.0;
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.gnu.org/licenses/lgpl.html
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package top.continew.starter.messaging.websocket.model;
import java.io.Serial;
import java.io.Serializable;
/**
* 当前登录用户信息
*
* @author Charles7c
* @since 2.1.0
*/
public class CurrentUser implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
/**
* 用户 ID
*/
private String userId;
/**
* 扩展字段
*/
private Object extend;
public String getUserId() {
return userId;
}
public void setUserId(String userId) {
this.userId = userId;
}
public Object getExtend() {
return extend;
}
public void setExtend(Object extend) {
this.extend = extend;
}
}

View File

@@ -44,11 +44,11 @@ public class WebSocketUtils {
/** /**
* 发送消息 * 发送消息
* *
* @param sessionKey 会话 Key * @param clientId 客户端 ID
* @param message 消息内容 * @param message 消息内容
*/ */
public static void sendMessage(String sessionKey, String message) { public static void sendMessage(String clientId, String message) {
WebSocketSession session = SESSION_DAO.get(sessionKey); WebSocketSession session = SESSION_DAO.get(clientId);
sendMessage(session, message); sendMessage(session, message);
} }

View File

@@ -32,7 +32,6 @@ import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice; import org.springframework.web.bind.annotation.RestControllerAdvice;
import org.springframework.web.method.annotation.MethodArgumentTypeMismatchException; import org.springframework.web.method.annotation.MethodArgumentTypeMismatchException;
import org.springframework.web.multipart.MaxUploadSizeExceededException;
import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartException;
import top.continew.starter.core.constant.StringConstants; import top.continew.starter.core.constant.StringConstants;
import top.continew.starter.core.exception.BadRequestException; import top.continew.starter.core.exception.BadRequestException;