diff --git a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java index aeaeea4c..b9b9c8d6 100644 --- a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java +++ b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java @@ -16,14 +16,23 @@ package top.continew.starter.security.crypto.core; +import cn.hutool.core.text.CharSequenceUtil; import cn.hutool.core.util.ReflectUtil; import cn.hutool.extra.spring.SpringUtil; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.mapping.MappedStatement; +import top.continew.starter.core.constant.StringConstants; +import top.continew.starter.core.exception.BaseException; import top.continew.starter.security.crypto.annotation.FieldEncrypt; import top.continew.starter.security.crypto.encryptor.IEncryptor; import top.continew.starter.security.crypto.enums.Algorithm; import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Stream; /** * 字段解密拦截器 @@ -33,13 +42,16 @@ import java.util.*; */ public abstract class AbstractMyBatisInterceptor { + private static final Map, List> CLASS_FIELD_CACHE = new ConcurrentHashMap<>(); + private static final Map> ENCRYPT_PARAM_CACHE = new ConcurrentHashMap<>(); + /** * 获取所有字符串类型、需要加/解密的、有值字段 * * @param obj 对象 * @return 字段列表 */ - public List getEncryptFields(Object obj) { + protected List getEncryptFields(Object obj) { if (null == obj) { return Collections.emptyList(); } @@ -52,11 +64,11 @@ public abstract class AbstractMyBatisInterceptor { * @param clazz 类型对象 * @return 字段列表 */ - public List getEncryptFields(Class clazz) { - return Arrays.stream(ReflectUtil.getFields(clazz)) + protected List getEncryptFields(Class clazz) { + return CLASS_FIELD_CACHE.computeIfAbsent(clazz, key -> Arrays.stream(ReflectUtil.getFields(clazz)) .filter(field -> String.class.equals(field.getType())) .filter(field -> null != field.getAnnotation(FieldEncrypt.class)) - .toList(); + .toList()); } /** @@ -65,7 +77,7 @@ public abstract class AbstractMyBatisInterceptor { * @param fieldEncrypt 字段加密注解 * @return 加/解密处理器 */ - public IEncryptor getEncryptor(FieldEncrypt fieldEncrypt) { + protected IEncryptor getEncryptor(FieldEncrypt fieldEncrypt) { Class encryptorClass = fieldEncrypt.encryptor(); // 使用预定义加/解密处理器 if (encryptorClass == IEncryptor.class) { @@ -75,4 +87,63 @@ public abstract class AbstractMyBatisInterceptor { // 使用自定义加/解密处理器 return SpringUtil.getBean(encryptorClass); } + + /** + * 获取加密参数 + * + * @param mappedStatement 映射语句 + * @return 获取加密参数 + */ + protected Map getEncryptParameters(MappedStatement mappedStatement) { + String mappedStatementId = mappedStatement.getId(); + return ENCRYPT_PARAM_CACHE.computeIfAbsent(mappedStatementId, key -> { + Method method = this.getMethod(mappedStatementId); + if (null == method) { + return Collections.emptyMap(); + } + Map encryptMap = new HashMap<>(); + Parameter[] parameters = method.getParameters(); + for (int i = 0; i < parameters.length; i++) { + Parameter parameter = parameters[i]; + FieldEncrypt fieldEncrypt = parameter.getAnnotation(FieldEncrypt.class); + if (null == fieldEncrypt) { + continue; + } + String parameterName = this.getParameterName(parameter); + encryptMap.put(parameterName, fieldEncrypt); + if (String.class.equals(parameter.getType())) { + encryptMap.put("param" + (i + 1), fieldEncrypt); + } + } + return encryptMap; + }); + } + + /** + * 获取映射方法 + * + * @param mappedStatementId 映射语句 ID + * @return 映射方法 + */ + private Method getMethod(String mappedStatementId) { + String className = CharSequenceUtil.subBefore(mappedStatementId, StringConstants.DOT, true); + String methodName = CharSequenceUtil.subAfter(mappedStatementId, StringConstants.DOT, true); + try { + Method[] methods = ReflectUtil.getMethods(Class.forName(className)); + return Stream.of(methods).filter(method -> method.getName().equals(methodName)).findFirst().orElse(null); + } catch (ClassNotFoundException e) { + throw new BaseException(e); + } + } + + /** + * 获取参数名称 + * + * @param parameter 参数 + * @return 参数名称 + */ + public String getParameterName(Parameter parameter) { + Param param = parameter.getAnnotation(Param.class); + return null != param ? param.value() : parameter.getName(); + } } \ No newline at end of file diff --git a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java index 36a1640d..9cd07b7d 100644 --- a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java +++ b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java @@ -17,6 +17,7 @@ package top.continew.starter.security.crypto.core; import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.util.ClassUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ReflectUtil; import com.baomidou.mybatisplus.core.conditions.AbstractWrapper; @@ -34,7 +35,6 @@ import top.continew.starter.security.crypto.autoconfigure.CryptoProperties; import top.continew.starter.security.crypto.encryptor.IEncryptor; import java.lang.reflect.Field; -import java.sql.SQLException; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -66,20 +66,12 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor implem return; } if (parameterObject instanceof Map parameterMap) { - Set set = new HashSet<>(parameterMap.values()); - for (Object parameter : set) { - if (parameter instanceof AbstractWrapper || parameter instanceof String) { - continue; - } - this.encryptEntity(super.getEncryptFields(parameter), parameter); - } + this.encryptQueryParameter(parameterMap, mappedStatement); } } @Override - public void beforeUpdate(Executor executor, - MappedStatement mappedStatement, - Object parameterObject) throws SQLException { + public void beforeUpdate(Executor executor, MappedStatement mappedStatement, Object parameterObject) { if (null == parameterObject) { return; } @@ -106,12 +98,39 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor implem } // 别名带有 ew(针对 MP 的 UpdateWrapper、LambdaUpdateWrapper 等参数) if (parameterMap.containsKey(Constants.WRAPPER) && null != (parameter = parameterMap.get(Constants.WRAPPER))) { - this.encryptWrapper(parameter, mappedStatement); + this.encryptUpdateWrapper(parameter, mappedStatement); } } /** - * 处理 Wrapper 类型参数加密(针对 MP 的 UpdateWrapper、LambdaUpdateWrapper 等参数) + * 加密查询参数(针对 Map 类型参数) + * + * @param parameterMap 参数 + * @param mappedStatement 映射语句 + */ + private void encryptQueryParameter(Map parameterMap, MappedStatement mappedStatement) { + Map encryptParameterMap = super.getEncryptParameters(mappedStatement); + for (Map.Entry parameterEntrySet : parameterMap.entrySet()) { + String parameterName = parameterEntrySet.getKey(); + Object parameterValue = parameterEntrySet.getValue(); + if (null == parameterValue || ClassUtil.isBasicType(parameterValue + .getClass()) || parameterValue instanceof AbstractWrapper) { + continue; + } + if (parameterValue instanceof String str) { + FieldEncrypt fieldEncrypt = encryptParameterMap.get(parameterName); + if (null != fieldEncrypt) { + parameterMap.put(parameterName, this.doEncrypt(str, fieldEncrypt)); + } + } else { + // 实体参数 + this.encryptEntity(super.getEncryptFields(parameterValue), parameterValue); + } + } + } + + /** + * 处理 UpdateWrapper 类型参数加密(针对 MP 的 UpdateWrapper、LambdaUpdateWrapper 等参数) * * @param parameter Wrapper 参数 * @param mappedStatement 映射语句 @@ -120,7 +139,7 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor implem * @author wangshaopeng@talkweb.com.cn(基于Mybatis-Plus拦截器实现MySQL数据加解密) */ - private void encryptWrapper(Object parameter, MappedStatement mappedStatement) { + private void encryptUpdateWrapper(Object parameter, MappedStatement mappedStatement) { if (parameter instanceof AbstractWrapper updateWrapper) { String sqlSet = updateWrapper.getSqlSet(); if (CharSequenceUtil.isBlank(sqlSet)) { @@ -146,12 +165,7 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor implem if (matcher.matches()) { String valueKey = matcher.group(1); Object value = updateWrapper.getParamNameValuePairs().get(valueKey); - Object ciphertext; - try { - ciphertext = this.doEncrypt(value, fieldEncrypt); - } catch (Exception e) { - throw new BaseException(e); - } + Object ciphertext = this.doEncrypt(value, fieldEncrypt); updateWrapper.getParamNameValuePairs().put(valueKey, ciphertext); } } @@ -189,15 +203,18 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor implem * * @param parameterValue 参数值 * @param fieldEncrypt 字段加密注解 - * @throws Exception / */ - private Object doEncrypt(Object parameterValue, FieldEncrypt fieldEncrypt) throws Exception { + private Object doEncrypt(Object parameterValue, FieldEncrypt fieldEncrypt) { if (null == parameterValue) { return null; } IEncryptor encryptor = super.getEncryptor(fieldEncrypt); // 优先获取自定义对称加密算法密钥,获取不到时再获取全局配置 String password = ObjectUtil.defaultIfBlank(fieldEncrypt.password(), properties.getPassword()); - return encryptor.encrypt(parameterValue.toString(), password, properties.getPublicKey()); + try { + return encryptor.encrypt(parameterValue.toString(), password, properties.getPublicKey()); + } catch (Exception e) { + throw new BaseException(e); + } } }