mirror of
				https://github.com/continew-org/continew-starter.git
				synced 2025-10-30 23:00:11 +08:00 
			
		
		
		
	fix(security/crypto): 修复处理 MP Wrapper 时 无法加密的情况 (#4)
This commit is contained in:
		| @@ -24,5 +24,6 @@ | |||||||
|             <groupId>com.baomidou</groupId> |             <groupId>com.baomidou</groupId> | ||||||
|             <artifactId>mybatis-plus-core</artifactId> |             <artifactId>mybatis-plus-core</artifactId> | ||||||
|         </dependency> |         </dependency> | ||||||
|  |          | ||||||
|     </dependencies> |     </dependencies> | ||||||
| </project> | </project> | ||||||
| @@ -53,7 +53,19 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor { | |||||||
|      * @return 加密参数 |      * @return 加密参数 | ||||||
|      */ |      */ | ||||||
|     public Map<String, FieldEncrypt> getEncryptParams(String mappedStatementId) { |     public Map<String, FieldEncrypt> getEncryptParams(String mappedStatementId) { | ||||||
|         return ENCRYPT_PARAM_CACHE.computeIfAbsent(mappedStatementId, this::getEncryptParamsNoCached); |         return getEncryptParams(mappedStatementId, null); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |      * 获取加密参数 | ||||||
|  |      * | ||||||
|  |      * @param mappedStatementId 映射语句 ID | ||||||
|  |      * @param parameterIndex    参数数量 | ||||||
|  |      * @return 加密参数 | ||||||
|  |      */ | ||||||
|  |     public Map<String, FieldEncrypt> getEncryptParams(String mappedStatementId, Integer parameterIndex) { | ||||||
|  |         return ENCRYPT_PARAM_CACHE | ||||||
|  |             .computeIfAbsent(mappedStatementId, it -> getEncryptParamsNoCached(mappedStatementId, parameterIndex)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
| @@ -105,9 +117,10 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor { | |||||||
|      * 获取参数列表(无缓存) |      * 获取参数列表(无缓存) | ||||||
|      * |      * | ||||||
|      * @param mappedStatementId 映射语句 ID |      * @param mappedStatementId 映射语句 ID | ||||||
|  |      * @param parameterIndex    参数数量 | ||||||
|      * @return 参数列表 |      * @return 参数列表 | ||||||
|      */ |      */ | ||||||
|     private Map<String, FieldEncrypt> getEncryptParamsNoCached(String mappedStatementId) { |     private Map<String, FieldEncrypt> getEncryptParamsNoCached(String mappedStatementId, Integer parameterIndex) { | ||||||
|         try { |         try { | ||||||
|             String className = CharSequenceUtil.subBefore(mappedStatementId, StringConstants.DOT, true); |             String className = CharSequenceUtil.subBefore(mappedStatementId, StringConstants.DOT, true); | ||||||
|             String wrapperMethodName = CharSequenceUtil.subAfter(mappedStatementId, StringConstants.DOT, true); |             String wrapperMethodName = CharSequenceUtil.subAfter(mappedStatementId, StringConstants.DOT, true); | ||||||
| @@ -117,8 +130,12 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor { | |||||||
|                 .map(suffix -> wrapperMethodName.substring(0, wrapperMethodName.length() - suffix.length())) |                 .map(suffix -> wrapperMethodName.substring(0, wrapperMethodName.length() - suffix.length())) | ||||||
|                 .orElse(wrapperMethodName); |                 .orElse(wrapperMethodName); | ||||||
|             // 获取真实方法 |             // 获取真实方法 | ||||||
|             Optional<Method> methodOptional = Arrays.stream(ReflectUtil.getMethods(Class |             Optional<Method> methodOptional = Arrays.stream(ReflectUtil.getMethods(Class.forName(className), m -> { | ||||||
|                 .forName(className), m -> Objects.equals(m.getName(), methodName))).findFirst(); |                 if (Objects.nonNull(parameterIndex)) { | ||||||
|  |                     return Objects.equals(m.getName(), methodName) && m.getParameterCount() == parameterIndex; | ||||||
|  |                 } | ||||||
|  |                 return Objects.equals(m.getName(), methodName); | ||||||
|  |             })).findFirst(); | ||||||
|             if (methodOptional.isEmpty()) { |             if (methodOptional.isEmpty()) { | ||||||
|                 return Collections.emptyMap(); |                 return Collections.emptyMap(); | ||||||
|             } |             } | ||||||
| @@ -136,6 +153,8 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor { | |||||||
|                     } |                     } | ||||||
|                 } else if (parameterName.startsWith(Constants.ENTITY)) { |                 } else if (parameterName.startsWith(Constants.ENTITY)) { | ||||||
|                     map.put(parameterName, null); |                     map.put(parameterName, null); | ||||||
|  |                 } else if (parameterName.startsWith(Constants.WRAPPER)) { | ||||||
|  |                     map.put(parameterName, null); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             return map; |             return map; | ||||||
|   | |||||||
| @@ -16,9 +16,17 @@ | |||||||
|  |  | ||||||
| package top.continew.starter.security.crypto.core; | package top.continew.starter.security.crypto.core; | ||||||
|  |  | ||||||
|  | import cn.hutool.core.text.CharSequenceUtil; | ||||||
| import cn.hutool.core.util.ObjectUtil; | import cn.hutool.core.util.ObjectUtil; | ||||||
| import cn.hutool.core.util.ReflectUtil; | import cn.hutool.core.util.ReflectUtil; | ||||||
|  | import com.baomidou.mybatisplus.core.conditions.AbstractWrapper; | ||||||
|  | import com.baomidou.mybatisplus.core.conditions.Wrapper; | ||||||
|  | import com.baomidou.mybatisplus.core.mapper.BaseMapper; | ||||||
|  | import com.baomidou.mybatisplus.core.metadata.TableFieldInfo; | ||||||
|  | import com.baomidou.mybatisplus.core.metadata.TableInfo; | ||||||
|  | import com.baomidou.mybatisplus.core.metadata.TableInfoHelper; | ||||||
| import com.baomidou.mybatisplus.core.toolkit.Constants; | import com.baomidou.mybatisplus.core.toolkit.Constants; | ||||||
|  | import com.baomidou.mybatisplus.core.toolkit.StringUtils; | ||||||
| import org.apache.ibatis.cache.CacheKey; | import org.apache.ibatis.cache.CacheKey; | ||||||
| import org.apache.ibatis.executor.Executor; | import org.apache.ibatis.executor.Executor; | ||||||
| import org.apache.ibatis.mapping.BoundSql; | import org.apache.ibatis.mapping.BoundSql; | ||||||
| @@ -28,11 +36,14 @@ import org.apache.ibatis.plugin.*; | |||||||
| import org.apache.ibatis.session.ResultHandler; | import org.apache.ibatis.session.ResultHandler; | ||||||
| import org.apache.ibatis.session.RowBounds; | import org.apache.ibatis.session.RowBounds; | ||||||
| import org.apache.ibatis.type.SimpleTypeRegistry; | import org.apache.ibatis.type.SimpleTypeRegistry; | ||||||
|  | import top.continew.starter.core.constant.StringConstants; | ||||||
| import top.continew.starter.security.crypto.annotation.FieldEncrypt; | import top.continew.starter.security.crypto.annotation.FieldEncrypt; | ||||||
| import top.continew.starter.security.crypto.autoconfigure.CryptoProperties; | import top.continew.starter.security.crypto.autoconfigure.CryptoProperties; | ||||||
| import top.continew.starter.security.crypto.encryptor.IEncryptor; | import top.continew.starter.security.crypto.encryptor.IEncryptor; | ||||||
|  |  | ||||||
| import java.lang.reflect.Field; | import java.lang.reflect.Field; | ||||||
|  | import java.lang.reflect.ParameterizedType; | ||||||
|  | import java.lang.reflect.Type; | ||||||
| import java.util.*; | import java.util.*; | ||||||
|  |  | ||||||
| /** | /** | ||||||
| @@ -99,13 +110,18 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor { | |||||||
|      * @throws Exception / |      * @throws Exception / | ||||||
|      */ |      */ | ||||||
|     private void encryptMap(HashMap<String, Object> parameterMap, MappedStatement mappedStatement) throws Exception { |     private void encryptMap(HashMap<String, Object> parameterMap, MappedStatement mappedStatement) throws Exception { | ||||||
|         Map<String, FieldEncrypt> encryptParamMap = super.getEncryptParams(mappedStatement.getId()); |         Map<String, FieldEncrypt> encryptParamMap = super.getEncryptParams(mappedStatement.getId(), parameterMap | ||||||
|  |             .isEmpty() ? null : parameterMap.size() / 2); | ||||||
|         for (Map.Entry<String, FieldEncrypt> encryptParamEntry : encryptParamMap.entrySet()) { |         for (Map.Entry<String, FieldEncrypt> encryptParamEntry : encryptParamMap.entrySet()) { | ||||||
|             String parameterName = encryptParamEntry.getKey(); |             String parameterName = encryptParamEntry.getKey(); | ||||||
|             if (parameterName.startsWith(Constants.ENTITY)) { |             if (parameterName.startsWith(Constants.ENTITY)) { | ||||||
|                 // 兼容 MyBatis Plus 封装的 update 相关方法,updateById、update |                 // 兼容 MyBatis Plus 封装的 update 相关方法,updateById、update | ||||||
|                 Object entity = parameterMap.getOrDefault(parameterName, null); |                 Object entity = parameterMap.getOrDefault(parameterName, null); | ||||||
|                 this.doEncrypt(this.getEncryptFields(entity), entity); |                 this.doEncrypt(this.getEncryptFields(entity), entity); | ||||||
|  |             } else if (parameterName.startsWith(Constants.WRAPPER)) { | ||||||
|  |                 Wrapper wrapper = (Wrapper)parameterMap.getOrDefault(parameterName, null); | ||||||
|  |                 // 处理 wrapper 的情况 | ||||||
|  |                 handleWrapperEncrypt(wrapper, mappedStatement); | ||||||
|             } else { |             } else { | ||||||
|                 FieldEncrypt fieldEncrypt = encryptParamEntry.getValue(); |                 FieldEncrypt fieldEncrypt = encryptParamEntry.getValue(); | ||||||
|                 parameterMap.put(parameterName, this.doEncrypt(parameterMap.get(parameterName), fieldEncrypt)); |                 parameterMap.put(parameterName, this.doEncrypt(parameterMap.get(parameterName), fieldEncrypt)); | ||||||
| @@ -148,4 +164,83 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor { | |||||||
|             ReflectUtil.setFieldValue(entity, field, ciphertext); |             ReflectUtil.setFieldValue(entity, field, ciphertext); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |      * 处理 wrapper 的加密情况 | ||||||
|  |      * | ||||||
|  |      * @param wrapper         wrapper 对象 | ||||||
|  |      * @param mappedStatement 映射语句 | ||||||
|  |      * @throws Exception / | ||||||
|  |      */ | ||||||
|  |     private void handleWrapperEncrypt(Wrapper wrapper, MappedStatement mappedStatement) throws Exception { | ||||||
|  |         if (wrapper instanceof AbstractWrapper abstractWrapper) { | ||||||
|  |             String sqlSet = abstractWrapper.getSqlSet(); | ||||||
|  |             if (StringUtils.isEmpty(sqlSet)) { | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |             String className = CharSequenceUtil.subBefore(mappedStatement.getId(), StringConstants.DOT, true); | ||||||
|  |             Class<?> mapperClass = Class.forName(className); | ||||||
|  |             Optional<Class> baseMapperGenerics = getDoByMapperClass(mapperClass, Optional.empty()); | ||||||
|  |             // 获取不到泛型对象 则不进行下面的逻辑 | ||||||
|  |             if (baseMapperGenerics.isEmpty()) { | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |             TableInfo tableInfo = TableInfoHelper.getTableInfo(baseMapperGenerics.get()); | ||||||
|  |             List<TableFieldInfo> fieldList = tableInfo.getFieldList(); | ||||||
|  |             // 将 name=#{ew.paramNameValuePairs.xxx},age=#{ew.paramNameValuePairs.xxx} 切出来 | ||||||
|  |             for (String sqlFragment : sqlSet.split(Constants.COMMA)) { | ||||||
|  |                 String columnName = sqlFragment.split(Constants.EQUALS)[0]; | ||||||
|  |                 // 截取其中的 xxx 字符 :#{ew.paramNameValuePairs.xxx} | ||||||
|  |                 String paramNameVal = sqlFragment.split(Constants.EQUALS)[1].substring(25, sqlFragment | ||||||
|  |                     .split(Constants.EQUALS)[1].length() - 1); | ||||||
|  |                 Optional<TableFieldInfo> fieldInfo = fieldList.stream() | ||||||
|  |                     .filter(f -> f.getColumn().equals(columnName)) | ||||||
|  |                     .findAny(); | ||||||
|  |                 if (fieldInfo.isPresent()) { | ||||||
|  |                     TableFieldInfo tableFieldInfo = fieldInfo.get(); | ||||||
|  |                     FieldEncrypt fieldEncrypt = tableFieldInfo.getField().getAnnotation(FieldEncrypt.class); | ||||||
|  |                     if (fieldEncrypt != null) { | ||||||
|  |                         Map<String, Object> paramNameValuePairs = abstractWrapper.getParamNameValuePairs(); | ||||||
|  |                         Object o = paramNameValuePairs.get(paramNameVal); | ||||||
|  |                         paramNameValuePairs.put(paramNameVal, this.doEncrypt(o, fieldEncrypt)); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |      * 从 Mapper 获取泛型 | ||||||
|  |      * | ||||||
|  |      * @param mapperClass      mapper class | ||||||
|  |      * @param tempResult 临时存储的泛型对象 | ||||||
|  |      * @return domain 对象 | ||||||
|  |      */ | ||||||
|  |     private static Optional<Class> getDoByMapperClass(Class<?> mapperClass, Optional<Class> tempResult) { | ||||||
|  |         Type[] genericInterfaces = mapperClass.getGenericInterfaces(); | ||||||
|  |         Optional<Class> result = tempResult; | ||||||
|  |         for (Type genericInterface : genericInterfaces) { | ||||||
|  |             if (genericInterface instanceof ParameterizedType parameterizedType) { | ||||||
|  |                 Type rawType = parameterizedType.getRawType(); | ||||||
|  |                 Type[] actualTypeArguments = parameterizedType.getActualTypeArguments(); | ||||||
|  |                 // 如果匹配上 BaseMapper 且泛型参数是 Class 类型,则直接返回 | ||||||
|  |                 if (rawType.equals(BaseMapper.class)) { | ||||||
|  |                     return actualTypeArguments[0] instanceof Class ? | ||||||
|  |                             Optional.of((Class) actualTypeArguments[0]) : result; | ||||||
|  |                 } else if (rawType instanceof Class interfaceClass) { | ||||||
|  |                     // 如果泛型参数是 Class 类型,则传递给递归调用 | ||||||
|  |                     if (actualTypeArguments[0] instanceof Class tempResultClass) { | ||||||
|  |                         result = Optional.of(tempResultClass); | ||||||
|  |                     } | ||||||
|  |                     // 递归调用,继续查找 | ||||||
|  |                     Optional<Class> innerResult = getDoByMapperClass(interfaceClass, result); | ||||||
|  |                     if (innerResult.isPresent()) { | ||||||
|  |                         return innerResult; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         // 如果没有找到,返回传递进来的 tempResult | ||||||
|  |         return Optional.empty(); | ||||||
|  |     } | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 cary
					cary