From 5c2ef3773e69808def66c9ee489ea9e1ef40045e Mon Sep 17 00:00:00 2001 From: Kevin Dost Date: Sat, 17 Oct 2020 20:10:52 +0200 Subject: [PATCH] Fix issue where sorting or filtering a collection fails on accesssing null members. --- Sieve/Extensions/OrderByDynamic.cs | 62 +++++++++++++++++++++--------- Sieve/Extensions/TypeExtentions.cs | 12 ++++++ Sieve/Services/SieveProcessor.cs | 44 ++++++++++++++------- SieveUnitTests/General.cs | 6 +-- 4 files changed, 89 insertions(+), 35 deletions(-) create mode 100644 Sieve/Extensions/TypeExtentions.cs diff --git a/Sieve/Extensions/OrderByDynamic.cs b/Sieve/Extensions/OrderByDynamic.cs index 5a885de..3fc8347 100644 --- a/Sieve/Extensions/OrderByDynamic.cs +++ b/Sieve/Extensions/OrderByDynamic.cs @@ -1,36 +1,62 @@ using System; using System.Linq; using System.Linq.Expressions; -using System.Reflection; namespace Sieve.Extensions { public static partial class LinqExtentions { - public static IQueryable OrderByDynamic(this IQueryable source, string fullPropertyName, PropertyInfo propertyInfo, - bool desc, bool useThenBy) + public static IQueryable OrderByDynamic( + this IQueryable source, + string fullPropertyName, + bool desc, + bool useThenBy) { - string command = desc ? - (useThenBy ? "ThenByDescending" : "OrderByDescending") : - (useThenBy ? "ThenBy" : "OrderBy"); - var type = typeof(TEntity); - var parameter = Expression.Parameter(type, "p"); + var lambda = GenerateLambdaWithSafeMemberAccess(fullPropertyName); - dynamic propertyValue = parameter; - if (fullPropertyName.Contains(".")) + var command = desc + ? (useThenBy ? "ThenByDescending" : "OrderByDescending") + : (useThenBy ? "ThenBy" : "OrderBy"); + + var resultExpression = Expression.Call( + typeof(Queryable), + command, + new Type[] { typeof(TEntity), lambda.ReturnType }, + source.Expression, + Expression.Quote(lambda)); + + return source.Provider.CreateQuery(resultExpression); + } + + private static Expression> GenerateLambdaWithSafeMemberAccess(string fullPropertyName) + { + var parameter = Expression.Parameter(typeof(TEntity), "e"); + Expression propertyValue = parameter; + Expression nullCheck = null; + + foreach (var name in fullPropertyName.Split('.')) { - var parts = fullPropertyName.Split('.'); - for (var i = 0; i < parts.Length - 1; i++) + propertyValue = Expression.PropertyOrField(propertyValue, name); + + if (propertyValue.Type.IsNullable()) { - propertyValue = Expression.PropertyOrField(propertyValue, parts[i]); + nullCheck = GenerateOrderNullCheckExpression(propertyValue, nullCheck); } } - var propertyAccess = Expression.MakeMemberAccess(propertyValue, propertyInfo); - var orderByExpression = Expression.Lambda(propertyAccess, parameter); - var resultExpression = Expression.Call(typeof(Queryable), command, new Type[] { type, propertyInfo.PropertyType }, - source.Expression, Expression.Quote(orderByExpression)); - return source.Provider.CreateQuery(resultExpression); + var expression = nullCheck == null + ? propertyValue + : Expression.Condition(nullCheck, Expression.Default(propertyValue.Type), propertyValue); + + var converted = Expression.Convert(expression, typeof(object)); + return Expression.Lambda>(converted, parameter); + } + + private static Expression GenerateOrderNullCheckExpression(Expression propertyValue, Expression nullCheckExpression) + { + return nullCheckExpression == null + ? Expression.Equal(propertyValue, Expression.Default(propertyValue.Type)) + : Expression.OrElse(nullCheckExpression, Expression.Equal(propertyValue, Expression.Default(propertyValue.Type))); } } } diff --git a/Sieve/Extensions/TypeExtentions.cs b/Sieve/Extensions/TypeExtentions.cs new file mode 100644 index 0000000..806cf56 --- /dev/null +++ b/Sieve/Extensions/TypeExtentions.cs @@ -0,0 +1,12 @@ +using System; + +namespace Sieve.Extensions +{ + public static partial class TypeExtentions + { + public static bool IsNullable(this Type type) + { + return !type.IsValueType || Nullable.GetUnderlyingType(type) != null; + } + } +} diff --git a/Sieve/Services/SieveProcessor.cs b/Sieve/Services/SieveProcessor.cs index 5913960..c57bc4d 100644 --- a/Sieve/Services/SieveProcessor.cs +++ b/Sieve/Services/SieveProcessor.cs @@ -170,21 +170,27 @@ namespace Sieve.Services } Expression outerExpression = null; - var parameterExpression = Expression.Parameter(typeof(TEntity), "e"); + var parameter = Expression.Parameter(typeof(TEntity), "e"); foreach (var filterTerm in model.GetFiltersParsed()) { Expression innerExpression = null; foreach (var filterTermName in filterTerm.Names) { - var (fullName, property) = GetSieveProperty(false, true, filterTermName); + var (fullPropertyName, property) = GetSieveProperty(false, true, filterTermName); if (property != null) { var converter = TypeDescriptor.GetConverter(property.PropertyType); + Expression propertyValue = parameter; + Expression nullCheck = null; - dynamic propertyValue = parameterExpression; - foreach (var part in fullName.Split('.')) + foreach (var name in fullPropertyName.Split('.')) { - propertyValue = Expression.PropertyOrField(propertyValue, part); + propertyValue = Expression.PropertyOrField(propertyValue, name); + + if (propertyValue.Type.IsNullable()) + { + nullCheck = GenerateFilterNullCheckExpression(propertyValue, nullCheck); + } } if (filterTerm.Values == null) continue; @@ -217,6 +223,11 @@ namespace Sieve.Services expression = Expression.Not(expression); } + if (nullCheck != null) + { + expression = Expression.AndAlso(nullCheck, expression); + } + if (innerExpression == null) { innerExpression = expression; @@ -251,7 +262,14 @@ namespace Sieve.Services } return outerExpression == null ? result - : result.Where(Expression.Lambda>(outerExpression, parameterExpression)); + : result.Where(Expression.Lambda>(outerExpression, parameter)); + } + + private static Expression GenerateFilterNullCheckExpression(Expression propertyValue, Expression nullCheckExpression) + { + return nullCheckExpression == null + ? Expression.NotEqual(propertyValue, Expression.Default(propertyValue.Type)) + : Expression.AndAlso(nullCheckExpression, Expression.NotEqual(propertyValue, Expression.Default(propertyValue.Type))); } private static Expression GetExpression(TFilterTerm filterTerm, dynamic filterValue, dynamic propertyValue) @@ -311,7 +329,7 @@ namespace Sieve.Services if (property != null) { - result = result.OrderByDynamic(fullName, property, sortTerm.Descending, useThenBy); + result = result.OrderByDynamic(fullName, sortTerm.Descending, useThenBy); } else { @@ -373,12 +391,12 @@ namespace Sieve.Services bool isCaseSensitive) { return Array.Find(typeof(TEntity).GetProperties(), p => - { - return p.GetCustomAttribute(typeof(SieveAttribute)) is SieveAttribute sieveAttribute - && (canSortRequired ? sieveAttribute.CanSort : true) - && (canFilterRequired ? sieveAttribute.CanFilter : true) - && ((sieveAttribute.Name ?? p.Name).Equals(name, isCaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase)); - }); + { + return p.GetCustomAttribute(typeof(SieveAttribute)) is SieveAttribute sieveAttribute + && (!canSortRequired || sieveAttribute.CanSort) + && (!canFilterRequired || sieveAttribute.CanFilter) + && (sieveAttribute.Name ?? p.Name).Equals(name, isCaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase); + }); } private IQueryable ApplyCustomMethod(IQueryable result, string name, object parent, object[] parameters, object[] optionalParameters = null) diff --git a/SieveUnitTests/General.cs b/SieveUnitTests/General.cs index b9dfa48..1c3f279 100644 --- a/SieveUnitTests/General.cs +++ b/SieveUnitTests/General.cs @@ -31,7 +31,6 @@ namespace SieveUnitTests LikeCount = 100, IsDraft = true, CategoryId = null, - TopComment = new Comment { Id = 0, Text = "A1" }, FeaturedComment = new Comment { Id = 4, Text = "A2" } }, new Post() { @@ -57,7 +56,7 @@ namespace SieveUnitTests LikeCount = 3, IsDraft = true, CategoryId = 2, - TopComment = new Comment { Id = 1, Text = "D1" }, + TopComment = new Comment { Id = 1 }, FeaturedComment = new Comment { Id = 7, Text = "D2" } }, }.AsQueryable(); @@ -388,11 +387,10 @@ namespace SieveUnitTests }; var result = _processor.Apply(model, _posts); - Assert.AreEqual(3, result.Count()); + Assert.AreEqual(2, result.Count()); var posts = result.ToList(); Assert.IsTrue(posts[0].TopComment.Text.Contains("B")); Assert.IsTrue(posts[1].TopComment.Text.Contains("C")); - Assert.IsTrue(posts[2].TopComment.Text.Contains("D")); } [TestMethod]