C# Linq源码分析之Take(四)

概要

本文主要对Take的优化方法进行源码分析,分析Take在配合Select,Where等常用的Linq扩展方法使用时候,如何实现优化处理。

本文涉及到Select, Where和Take和三个方法的源码分析,其中Select, Where, Take更详尽的源码分析,请参考我之前写的文章。

源码分析

我们之前对Take的源码分析,主要是真对数据序列对象直接调用的Take方法的情况。本文介绍的Take优化方法,主要是针对多个Linq方法配合使用的情况,例如xx.Where.Select.Take或者xx.Select.Take的情况。

Take的优化方法,是定义在Take.SpeedOpt.cs文件中,体现在下面的TakeIterator方法中,Take方法对TakeIterator方法的具体调用方式,请参考C# Linq源码分析之Take方法

csharp 复制代码
private static IEnumerable<TSource> TakeIterator<TSource>(IEnumerable<TSource> source, int count)
 {
     Debug.Assert(source != null);
     Debug.Assert(count > 0);

     return
         source is IPartition<TSource> partition ? partition.Take(count) :
         source is IList<TSource> sourceList ? new ListPartition<TSource>(sourceList, 0, count - 1) :
         new EnumerablePartition<TSource>(source, 0, count - 1);
 }

该优化方法的基本逻辑如下:

  1. 如果source序列实现了IPartition接口,则调用IPartition中的Take方法;
  2. 如果source序列实现了IList接口,则返回ListPartition对象;
  3. 返回EnumerablePartition对象。

案例分析

Select.Take

该场景模拟我们显示中将EF中与数据库关联的对象,转换成Web前端需要的对象,并分页的情况。

将Student对象中的学生姓名和考试成绩取出后,并分页,Student类的定义和初始化请见附录。

csharp 复制代码
 studentList.Select(x => new {
                x.Name, x.MathResult
            }).take(3).ToList();

执行流程如上图所示:

  1. 进入Select方法后返回一个SelectListIterator对象,该对象存储了source序列数据和selector;
  2. 进入Take方法后, 因为SelectListIterator实现了IPartition接口,因此可以调用自己的Take方法,使用source,selector和count(Take方法的参数)实例化SelectListPartitionIterator对象;
csharp 复制代码
  public IPartition<TResult> Take(int count)
{
    Debug.Assert(count > 0);
    int maxIndex = _minIndexInclusive + count - 1;
    return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new SelectListPartitionIterator<TSource, TResult>(_source, _selector, _minIndexInclusive, maxIndex);
}
  1. 进入ToList方法后,根据下面的代码,因为SelectListPartitionIterator对象实现了IPartition接口,而IPartition又继承了IIListProvider接口,所以listProvider.ToList()被执行。
csharp 复制代码
   public static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
        {
if (source == null)
 {
     ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
 }

 return source is IIListProvider<TSource> listProvider ? listProvider.ToList() : new List<TSource>(source);
}

listProvider.ToList() 被执行,即SelectListPartitionIterator对象内定义的ToList()方法被执行,该ToList方法可以将

Selecxt的投影操作,Take的取值操作同时执行,代码如下:

csharp 复制代码
 public List<TResult> ToList()
{
    int count = Count;
    if (count == 0)
    {
        return new List<TResult>();
    }

    List<TResult> list = new List<TResult>(count);
    int end = _minIndexInclusive + count;
    for (int i = _minIndexInclusive; i != end; ++i)
    {
        list.Add(_selector(_source[i]));
    }

    return list;
}
  1. 如果Take的取值为0,直接返回空List,不会进行Select操作;
  2. 按照Take中指定的Count,取到相应的元素,再进行投影操作;
  3. 返回操作结果。

这样,源List序列source,Linq只遍历和一次,就同时完成了投影和过滤两个操作,实现了优化。

附录

Student类

csharp 复制代码
public class Student {
    public string Id { get; set; }
    public string Name { get; set; }
    public string Classroom { get; set; }
    public int MathResult { get; set; }
}

IIListProvider接口

csharp 复制代码
internal interface IIListProvider<TElement> : IEnumerable<TElement>
{
    TElement[] ToArray();
    List<TElement> ToList();
    int GetCount(bool onlyIfCheap);
}

IPartition接口

csharp 复制代码
internal interface IPartition<TElement> : IIListProvider<TElement>
{

    IPartition<TElement> Skip(int count);

    IPartition<TElement> Take(int count);

    TElement? TryGetElementAt(int index, out bool found);

    TElement? TryGetFirst(out bool found);

    TElement? TryGetLast(out bool found);
}

SelectListPartitionIterator类

csharp 复制代码
 private sealed class SelectListPartitionIterator<TSource, TResult> : Iterator<TResult>, IPartition<TResult>
{
    private readonly IList<TSource> _source;
    private readonly Func<TSource, TResult> _selector;
    private readonly int _minIndexInclusive;
    private readonly int _maxIndexInclusive;

    public SelectListPartitionIterator(IList<TSource> source, Func<TSource, TResult> selector, int minIndexInclusive, int maxIndexInclusive)
    {
        Debug.Assert(source != null);
        Debug.Assert(selector != null);
        Debug.Assert(minIndexInclusive >= 0);
        Debug.Assert(minIndexInclusive <= maxIndexInclusive);
        _source = source;
        _selector = selector;
        _minIndexInclusive = minIndexInclusive;
        _maxIndexInclusive = maxIndexInclusive;
    }

    public override Iterator<TResult> Clone() =>
        new SelectListPartitionIterator<TSource, TResult>(_source, _selector, _minIndexInclusive, _maxIndexInclusive);

    public override bool MoveNext()
    {
        // _state - 1 represents the zero-based index into the list.
        // Having a separate field for the index would be more readable. However, we save it
        // into _state with a bias to minimize field size of the iterator.
        int index = _state - 1;
        if (unchecked((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive))
        {
            _current = _selector(_source[_minIndexInclusive + index]);
            ++_state;
            return true;
        }

        Dispose();
        return false;
    }

    public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> selector) =>
        new SelectListPartitionIterator<TSource, TResult2>(_source, CombineSelectors(_selector, selector), _minIndexInclusive, _maxIndexInclusive);

    public IPartition<TResult> Skip(int count)
    {
        Debug.Assert(count > 0);
        int minIndex = _minIndexInclusive + count;
        return (uint)minIndex > (uint)_maxIndexInclusive ? EmptyPartition<TResult>.Instance : new SelectListPartitionIterator<TSource, TResult>(_source, _selector, minIndex, _maxIndexInclusive);
    }

    public IPartition<TResult> Take(int count)
    {
        Debug.Assert(count > 0);
        int maxIndex = _minIndexInclusive + count - 1;
        return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new SelectListPartitionIterator<TSource, TResult>(_source, _selector, _minIndexInclusive, maxIndex);
    }

    public TResult? TryGetElementAt(int index, out bool found)
    {
        if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive)
        {
            found = true;
            return _selector(_source[_minIndexInclusive + index]);
        }

        found = false;
        return default;
    }

    public TResult? TryGetFirst(out bool found)
    {
        if (_source.Count > _minIndexInclusive)
        {
            found = true;
            return _selector(_source[_minIndexInclusive]);
        }

        found = false;
        return default;
    }

    public TResult? TryGetLast(out bool found)
    {
        int lastIndex = _source.Count - 1;
        if (lastIndex >= _minIndexInclusive)
        {
            found = true;
            return _selector(_source[Math.Min(lastIndex, _maxIndexInclusive)]);
        }

        found = false;
        return default;
    }

    private int Count
    {
        get
        {
            int count = _source.Count;
            if (count <= _minIndexInclusive)
            {
                return 0;
            }

            return Math.Min(count - 1, _maxIndexInclusive) - _minIndexInclusive + 1;
        }
    }

    public TResult[] ToArray()
    {
        int count = Count;
        if (count == 0)
        {
            return Array.Empty<TResult>();
        }

        TResult[] array = new TResult[count];
        for (int i = 0, curIdx = _minIndexInclusive; i != array.Length; ++i, ++curIdx)
        {
            array[i] = _selector(_source[curIdx]);
        }

        return array;
    }

    public List<TResult> ToList()
    {
        int count = Count;
        if (count == 0)
        {
            return new List<TResult>();
        }

        List<TResult> list = new List<TResult>(count);
        int end = _minIndexInclusive + count;
        for (int i = _minIndexInclusive; i != end; ++i)
        {
            list.Add(_selector(_source[i]));
        }

        return list;
    }

    public int GetCount(bool onlyIfCheap)
    {
        // In case someone uses Count() to force evaluation of
        // the selector, run it provided `onlyIfCheap` is false.

        int count = Count;

        if (!onlyIfCheap)
        {
            int end = _minIndexInclusive + count;
            for (int i = _minIndexInclusive; i != end; ++i)
            {
                _selector(_source[i]);
            }
        }

        return count;
    }
}

SelectListIterator类 Select.cs

csharp 复制代码
 private sealed partial class SelectListIterator<TSource, TResult> : Iterator<TResult>
 {
     private readonly List<TSource> _source;
     private readonly Func<TSource, TResult> _selector;
     private List<TSource>.Enumerator _enumerator;

     public SelectListIterator(List<TSource> source, Func<TSource, TResult> selector)
     {
         Debug.Assert(source != null);
         Debug.Assert(selector != null);
         _source = source;
         _selector = selector;
     }

     private int CountForDebugger => _source.Count;

     public override Iterator<TResult> Clone() => new SelectListIterator<TSource, TResult>(_source, _selector);

     public override bool MoveNext()
     {
         switch (_state)
         {
             case 1:
                 _enumerator = _source.GetEnumerator();
                 _state = 2;
                 goto case 2;
             case 2:
                 if (_enumerator.MoveNext())
                 {
                     _current = _selector(_enumerator.Current);
                     return true;
                 }

                 Dispose();
                 break;
         }

         return false;
     }

     public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> selector) =>
         new SelectListIterator<TSource, TResult2>(_source, CombineSelectors(_selector, selector));
 }

Select.SpeedOpt.cs

csharp 复制代码
private sealed partial class SelectListIterator<TSource, TResult> : IPartition<TResult>
 {
      public TResult[] ToArray()
      {
          int count = _source.Count;
          if (count == 0)
          {
              return Array.Empty<TResult>();
          }

          var results = new TResult[count];
          for (int i = 0; i < results.Length; i++)
          {
              results[i] = _selector(_source[i]);
          }

          return results;
      }

      public List<TResult> ToList()
      {
          int count = _source.Count;
          var results = new List<TResult>(count);
          for (int i = 0; i < count; i++)
          {
              results.Add(_selector(_source[i]));
          }

          return results;
      }

      public int GetCount(bool onlyIfCheap)
      {
          // In case someone uses Count() to force evaluation of
          // the selector, run it provided `onlyIfCheap` is false.

          int count = _source.Count;

          if (!onlyIfCheap)
          {
              for (int i = 0; i < count; i++)
              {
                  _selector(_source[i]);
              }
          }

          return count;
      }

      public IPartition<TResult> Skip(int count)
      {
          Debug.Assert(count > 0);
          return new SelectListPartitionIterator<TSource, TResult>(_source, _selector, count, int.MaxValue);
      }

      public IPartition<TResult> Take(int count)
      {
          Debug.Assert(count > 0);
          return new SelectListPartitionIterator<TSource, TResult>(_source, _selector, 0, count - 1);
      }

      public TResult? TryGetElementAt(int index, out bool found)
      {
          if (unchecked((uint)index < (uint)_source.Count))
          {
              found = true;
              return _selector(_source[index]);
          }

          found = false;
          return default;
      }

      public TResult? TryGetFirst(out bool found)
      {
          if (_source.Count != 0)
          {
              found = true;
              return _selector(_source[0]);
          }

          found = false;
          return default;
      }

      public TResult? TryGetLast(out bool found)
      {
          int len = _source.Count;
          if (len != 0)
          {
              found = true;
              return _selector(_source[len - 1]);
          }

          found = false;
          return default;
      }
  }
相关推荐
Dovis(誓平步青云)33 分钟前
《QT学习第四篇:常见事件与UDP、TCP、文件系统、(锁、信号量、条件变量》
c语言·开发语言·汇编·qt
isyangli_blog9 小时前
OpenDayLight (Carbon 版本) 启动与组件安装
开发语言·php
vb2008119 小时前
FastAPI APIRouter
开发语言·python
Benszen9 小时前
KVM虚拟化解决方案
开发语言·perl
会编程的土豆9 小时前
Go 语言反射(Reflection)详解
开发语言·后端·golang
東雪木9 小时前
多线程与并发编程 专属复习笔记
java·开发语言·笔记·java面试
杨充10 小时前
1.3 浮点型数据设计灵魂
开发语言·python·算法
噜噜噜阿鲁~10 小时前
python学习笔记 | 11.3、面向对象高级编程-多重继承
java·开发语言
basketball61610 小时前
Go 语言从入门到进阶:4. 数组和MAP使用方法总结
开发语言·后端·golang
春生野草10 小时前
反射、Tomcat执行
java·开发语言