;;;
;;; Array operations.  This is a superset of SRFI-25.
;;;
;;;  Copyright (c) 2002-2025  Shiro Kawai  <shiro@acm.org>
;;;  Copyright(C) 2004      by Alex Shinn  (foof@synthcode.com)
;;;
;;;  Permission to use, copy, modify, distribute this software and
;;;  accompanying documentation for any purpose is hereby granted,
;;;  provided that existing copyright notices are retained in all
;;;  copies and that this notice is included verbatim in all
;;;  distributions.
;;;  This software is provided as is, without express or implied
;;;  warranty.  In no circumstances the author(s) shall be liable
;;;  for any damages arising out of the use of this software.
;;;

;; Conceptually, an array is a backing storage and a procedure to
;; map n-dimensional indices to an index of the backing storage.

(define-module gauche.array
  (use scheme.list)
  (use gauche.uvector)
  (use gauche.collection)
  (use gauche.sequence)
  (export <array-meta> <array>
          <u8array> <s8array> <u16array> <s16array> <u32array> <s32array>
          <u64array> <s64array> <f16array> <f32array> <f64array>
          <c32array> <c64array> <c128array>
          array? make-array array-copy shape array array-rank
          array-start array-end array-ref array-set!
          share-array subarray array-equal?
          array-valid-index?  shape-valid-index?
          array-shape array-length array-size
          array-for-each-index shape-for-each
          array-for-each-index-by-dimension
          array-for-each array-every array-any
          tabulate-array array-retabulate!
          array-map array-map! array->vector array->list
          array->nested-list array->nested-vector
          make-u8array make-s8array make-u16array make-s16array
          make-u32array make-s32array make-u64array make-s64array
          make-f16array make-f32array make-f64array
          make-c32array make-c64array make-c128array
          u8array s8array u16array s16array u32array s32array
          u64array s64array f16array f32array f64array
          c32array c64array c128array
          array-concatenate array-transpose array-rotate-90
          array-flip array-flip!
          identity-array array-inverse determinant determinant! array-trace
          array-mul array-vector-mul vector-array-mul array-expt
          array-div-left array-div-right
          array-add-elements array-add-elements!
          array-sub-elements array-sub-elements!
          array-negate-elements array-negate-elements!
          array-mul-elements array-mul-elements!
          array-div-elements array-div-elements!
          array-reciprocate-elements array-reciprocate-elements!
          format-array/prefix format-array/content ; for pretty printer
          pretty-print-array
          ))
(select-module gauche.array)

(autoload "gauche/matrix"
  array-concatenate array-transpose array-rotate-90 array-flip array-flip!
  identity-array array-inverse determinant determinant! array-trace
  array-mul array-vector-mul vector-array-mul array-expt
  array-div-left array-div-right
  array-add-elements array-add-elements!
  array-sub-elements array-sub-elements!
  array-negate-elements array-negate-elements!
  array-mul-elements array-mul-elements!
  array-div-elements array-div-elements!
  array-reciprocate-elements array-reciprocate-elements!
  pretty-print-array)

;; Metaclass of the array classes.  In addition to what <class> holds, each
;; array class records the class of its backing storage, the tag used in
;; its external representation, and a few element-type attributes that
;; are queried through the helpers array-tag, non-real?, element-size, etc.
(define-class <array-meta> (<class>)
  ((backing-storage-class :init-keyword :backing-storage-class)
   (tag :init-keyword :tag)             ;for external representation
   (signed :init-keyword :signed        ;metainfo used in matrix op
           :init-value #f)
   (integral :init-keyword :integral    ;ditto
             :init-value #f)
   (real :init-keyword :real            ;ditto
         :init-value #f)
   (element-size :init-keyword :element-size ;ditto
                 :init-value 0)
   ))

;; auto reader and writer for array-meta

;; Whenever a named array class is created, register a reader constructor
;; under the class name.  The constructor takes a list of shape bounds
;; followed by the elements, which is the form produced by
;; format-array/content for the reader-ctor style.
(define-method initialize ((class <array-meta>) initargs)
  (next-method)
  (cond
   [(class-name class)
    => (lambda (name)
         (define-reader-ctor name
           (lambda (sh . inits)
             (make-array-internal class (apply shape sh)
                                  :init-list inits))))]))

;; <array-base> is defined in the core; we only import its binding from
;; gauche.internal so that the classes and methods below can refer to it.
(define <array-base> (with-module gauche.internal <array-base>))

;; Accessor for the start-vector slot of an array.
(define (start-vector-of base)
  (ref base 'start-vector))
;; Accessor for the end-vector slot of an array.
(define (end-vector-of base)
  (ref base 'end-vector))

;;
;; Formatting array
;;

;; Writes an array as a prefix string followed by its content datum.
;; The style is taken from the 'array entry of the port's write controls.
(define-method write-object ((self <array-base>) port)
  (let* ([controls ((with-module gauche.internal %port-write-controls) port)]
         [style    (~ controls 'array)])
    (display (format-array/prefix self style) port)
    (write (format-array/content self style) port)))

;; Returns the prefix string of ARRAY's external representation.
;; STYLE must be one of compact, dimensions or reader-ctor; anything else
;; is rejected by ecase.
(define (format-array/prefix array style)
  (ecase style
    [(compact dimensions) (format-array/srfi-163-prefix array style)]
    [(reader-ctor) "#,"]))

;; Returns the datum (not a string) that should be written after the
;; prefix.  For compact/dimensions it is a nested list of the elements;
;; for reader-ctor it is (class-name shape-bounds-list element ...).
(define (format-array/content array style)
  (ecase style
    [(compact dimensions) (array->nested-list array)]
    [(reader-ctor)
     (let ([name   (class-name (class-of array))]
           [bounds (array->list (array-shape array))]
           [elts   (array->list array)])
       (cons name (cons bounds elts)))]))

;; #<rank>a style.  FMT can be 'compact or 'dimensions
;;  If compact,
;;    - If all dimensions start from 0, we omit dimensions altogether.
;;    - Otherwise, we show start index for every dimension
;;  If dimensions
;;    - We show start if it's not 0
;;    - We show length for all dimensions
(define (format-array/srfi-163-prefix array fmt)
  (define rank (array-rank array))
  (define axes (iota rank))
  ;; Renders the per-axis part of the prefix.
  ;; With FULL? the start is shown only when nonzero and the length is
  ;; always shown; without it only the start is shown, for every axis.
  (define (dims full?)
    (with-output-to-string
      (lambda ()
        (for-each (^[axis]
                    (let ([lo (array-start array axis)]
                          [hi (array-end array axis)])
                      (when (or (not full?) (not (= lo 0)))
                        (display #\@)
                        (display lo))
                      (when full?
                        (display #\:)
                        (display (- hi lo)))))
                  axes))))
  ;; #f means "no dimension part", which ~@[...~] turns into nothing.
  (define dim-part
    (cond [(eq? fmt 'dimensions) (dims #t)]
          [(every (^[axis] (zero? (array-start array axis))) axes) #f]
          [else (dims #f)]))

  (format "#~a~a~@[~a~]~a"
          rank
          (array-tag (class-of array))
          dim-part
          ;; a zero-dimensional array needs a separator before its content
          (if (zero? rank) " " "")))

;; General array; elements are kept in a <vector>.  None of the numeric
;; attributes is given, so they keep the <array-meta> defaults.
(define-class <array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <vector>
  :tag 'a)

;; Array backed by a <u8vector>: element size 8, integral, real,
;; not signed (the default).
(define-class <u8array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <u8vector>
  :tag 'u8
  :element-size 8
  :integral #t
  :real #t)

;; Array backed by an <s8vector>: element size 8, signed, integral, real.
(define-class <s8array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <s8vector>
  :tag 's8
  :element-size 8
  :signed #t
  :integral #t
  :real #t)

;; Array backed by a <u16vector>: element size 16, integral, real,
;; not signed (the default).
(define-class <u16array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <u16vector>
  :tag 'u16
  :element-size 16
  :integral #t
  :real #t)

;; Array backed by an <s16vector>: element size 16, signed, integral, real.
(define-class <s16array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <s16vector>
  :tag 's16
  :element-size 16
  :signed #t
  :integral #t
  :real #t)

;; Array backed by a <u32vector>: element size 32, integral, real,
;; not signed (the default).
(define-class <u32array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <u32vector>
  :tag 'u32
  :element-size 32
  :integral #t
  :real #t)

;; Array backed by an <s32vector>: element size 32, signed, integral, real.
(define-class <s32array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <s32vector>
  :tag 's32
  :element-size 32
  :signed #t
  :integral #t
  :real #t)

;; Array backed by a <u64vector>: element size 64, integral, real,
;; not signed (the default).
(define-class <u64array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <u64vector>
  :tag 'u64
  :element-size 64
  :integral #t
  :real #t)

;; Array backed by an <s64vector>: element size 64, signed, integral, real.
(define-class <s64array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <s64vector>
  :tag 's64
  :element-size 64
  :signed #t
  :integral #t
  :real #t)

;; Array backed by an <f16vector>: element size 16, signed, real,
;; not integral (the default).
(define-class <f16array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <f16vector>
  :tag 'f16
  :element-size 16
  :signed #t
  :real #t)

;; Array backed by an <f32vector>: element size 32, signed, real,
;; not integral (the default).
(define-class <f32array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <f32vector>
  :tag 'f32
  :element-size 32
  :signed #t
  :real #t)

;; Array backed by an <f64vector>: element size 64, signed, real,
;; not integral (the default).
(define-class <f64array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <f64vector>
  :tag 'f64
  :element-size 64
  :signed #t
  :real #t)

;; Array backed by a <c32vector>: element size 32.  The signed, integral
;; and real attributes are left at their #f defaults.
(define-class <c32array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <c32vector>
  :tag 'c32
  :element-size 32)

;; Array backed by a <c64vector>: element size 64.  The signed, integral
;; and real attributes are left at their #f defaults.
(define-class <c64array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <c64vector>
  :tag 'c64
  :element-size 64)

;; Array backed by a <c128vector>: element size 128.  The signed, integral
;; and real attributes are left at their #f defaults.
(define-class <c128array> (<array-base>)
  ()
  :metaclass <array-meta>
  :backing-storage-class <c128vector>
  :tag 'c128
  :element-size 128)

;; Utility to inquire array-type attributes

;; Each helper takes an array CLASS (an instance of <array-meta>).

;; Tag symbol used in the external representation.
(define (array-tag class) (ref class 'tag))
;; True only for the general <array> class.
(define (non-numeric? class) (eq? <array> class))
;; True when the class isn't marked real.
(define (non-real? class) (if (ref class 'real) #f #t))
;; True when the class isn't marked integral.
(define (non-integral? class) (if (ref class 'integral) #f #t))
;; True for numeric classes that aren't integral.
(define (inexact-numeric? class)
  (if (non-numeric? class)
    #f
    (non-integral? class)))
;; Returns the 'signed attribute as is.
;; NOTE(review): this doesn't look at 'integral, so it is also true for
;; the f16/f32/f64 classes -- confirm that callers expect that.
(define (signed-integral? class) (ref class 'signed))
;; Value of the 'element-size attribute (0 when the class doesn't set it).
(define (element-size class) (ref class 'element-size))

;; internal
;; Copies V, which must be either a vector or a uvector.
(define-inline (%xvector-copy v)
  (if (vector? v)
    (vector-copy v)
    (if (uvector? v)
      (uvector-copy v)
      (error "Vector or uvector required, but got:" v))))

;; Returns a new array of the same class as A whose backing storage is a
;; copy of A's, so modifying one doesn't affect the other.  The start, end
;; and coefficient vectors are reused as they are.
;; Signals an error if A isn't an array.
(define (array-copy a)
  (assume-type a <array-base>)
  (make (class-of a)
    :start-vector (~ a'start-vector)
    :end-vector   (~ a'end-vector)
    :coefficient-vector (~ a'coefficient-vector)
    ;; An array created by share-array carries an offset that goes together
    ;; with its coefficient vector; without it the copy would address
    ;; different elements of the (copied) backing storage.
    :offset (~ a'offset)
    :backing-storage (%xvector-copy (~ a'backing-storage))))

;;-------------------------------------------------------------
;; Literal array reader
;; This is not exported API, but called from read.c when it sees #a(...) etc.
;; The reader has already read the optional RANK, then the subsequent character
;; which must be in #[asu].  That character is passed as TYPE-CHAR.

;; Reads the rest of an array literal from PORT and returns the array.
;;   RANK      - the rank read by the caller; a negative value is treated
;;               as "not given".
;;   TYPE-CHAR - the character that followed the rank.
;;   CTX       - not referenced in this procedure.
;; Signals <read-error> when the prefix is malformed or when the rank, the
;; dimensions and the content don't agree.
(define (%read-array-literal port rank type-char ctx)
  (define line (port-current-line port))
  ;; Characters consumed so far, newest first; only used to show the
  ;; offending prefix in error messages.
  (define chars `(,type-char))          ; for error message
  (define (save-char!)
    (let1 ch (read-char port)
      (unless (eof-object? ch) (push! chars ch))))
  (define (prefix) (list->string (reverse chars)))
  (define (err msg content)
    (errorf <read-error> :port port :line line
            (string-append msg ": #~a~a~@[~s~]")
            (if (< rank 0) "" rank) (prefix) content))
  (define (bad-prefix)
    ;; We read up to the delimiter so that the subsequent read won't be
    ;; tripped.
    (let loop ((ch (peek-char port)))
      (unless (or (eof-object? ch) (#[\s\(\[\{#\"'`,] ch))
        (save-char!) (loop (peek-char port))))
    (err "Invalid array literal prefix" #f))
  ;; NOTE(review): make-type-tag isn't referenced anywhere in this procedure.
  (define (make-type-tag nbits)
    (string->symbol (format "~c~d" (char-down-case type-char) nbits)))
  ;; Array class selected by TYPE-CHAR and, unless it is a/A, by the element
  ;; size digits that follow it.  Computing this consumes those digits.
  (define type-class
    (if (#[aA] type-char)
      <array>
      (case (read-digits #f)
        [(8)  (case type-char
                [(#\s #\S) <s8array>]
                [(#\u #\U) <u8array>]
                [else (bad-prefix)])]
        [(16) (case type-char
                [(#\f #\F) <f16array>]
                [(#\s #\S) <s16array>]
                [(#\u #\U) <u16array>]
                [else (bad-prefix)])]
        [(32) (case type-char
                [(#\c #\C) <c32array>]
                [(#\f #\F) <f32array>]
                [(#\s #\S) <s32array>]
                [(#\u #\U) <u32array>]
                [else (bad-prefix)])]
        [(64) (case type-char
                [(#\c #\C) <c64array>]
                [(#\f #\F) <f64array>]
                [(#\s #\S) <s64array>]
                [(#\u #\U) <u64array>]
                [else (bad-prefix)])]
        [(128) (case type-char
                [(#\c #\C) <c128array>]
                [else (bad-prefix)])]
        [else (bad-prefix)])))
  ;; read-dimensions, read-start and read-length together parse a sequence
  ;; of "@start" and ":length" items.  R accumulates (start length) entries
  ;; in reverse; length is #f when only the start was given.  Parsing stops,
  ;; without consuming, at an open paren or a space.
  (define (read-dimensions r)
    (case (peek-char port)
      [(#\( #\space) (reverse r)]
      [(#\@) (save-char!) (read-start r)]
      [(#\:) (save-char!) (read-length 0 r)]
      [(#\" #\| #\[ #\{) (bad-prefix)]  ; not to trip subsequent read
      [else  (save-char!) (bad-prefix)]))
  ;; Reads a run of decimal digits, with an optional leading sign when
  ;; ALLOW-SIGN? is true, and returns it as a number.
  ;; NOTE(review): when only a sign is present, DS isn't empty and the
  ;; result comes from string->number on "+" or "-" -- presumably #f;
  ;; confirm whether that case should go to bad-prefix instead.
  (define (read-digits allow-sign?)
    (let loop ([ch (peek-char port)]
               [ds '()])
      (cond
       [(eof-object? ch) (bad-prefix)]
       [(#[0-9] ch)
        (save-char!)
        (loop (peek-char port) (cons ch ds))]
       [(and allow-sign? (#[+-] ch))
        (save-char!)
        (if (null? ds)
          (loop (peek-char port) (cons ch ds))
          (bad-prefix))]
       [else
        (if (null? ds)
          (bad-prefix)
          ($ string->number $ list->string $ reverse ds))])))
  ;; Called after "@" has been consumed.
  (define (read-start r)
    (let* ([n (read-digits #t)]
           [ch (peek-char port)])
      (case ch
        [(#\:) (save-char!) (read-length n r)]
        [(#\@) (save-char!) (read-start `((,n #f) ,@r))]
        [(#\( #\space) (reverse `((,n #f) ,@r))]
        [(#\" #\| #\[ #\{) (bad-prefix)]  ; not to trip subsequent read
        [else  (save-char!) (bad-prefix)])))
  ;; Called after ":" has been consumed; START is the start index of the
  ;; dimension whose length is being read.
  (define (read-length start r)
    (let* ([n (read-digits #t)]
           [ch (peek-char port)])
      (case ch
        [(#\:) (save-char!) (read-length 0 `((,start ,n) ,@r))]
        [(#\@) (save-char!) (read-start `((,start ,n) ,@r))]
        [(#\( #\space) (reverse `((,start ,n) ,@r))]
        [(#\" #\| #\[ #\{) (bad-prefix)]  ; not to trip subsequent read
        [else  (save-char!) (bad-prefix)])))
  ;; Checks one level of CONTENT against SUGGESTED, a (start length) entry
  ;; or a non-pair when nothing is suggested.  A missing start defaults to 0
  ;; and a missing length to the actual length of CONTENT.
  (define (dim-check suggested content) ;returns #f when bad shape
    (let* ([start (if (pair? suggested) (car suggested) 0)]
           [actual (length content)]
           [len (or (and (pair? suggested) (cadr suggested)) actual)])
      (and (= actual len) `(,start ,len))))
  ;; Walks CONTENTS level by level and returns the list of (start length)
  ;; entries, one per level.  The fold threads the entries found for the
  ;; inner levels through every sibling, so that all siblings have to agree.
  (define (dim-check-all dims depth contents) ;returns #f when bad shape
    (cond [(= depth rank) '()]
          [(null? dims)
           (if (list? contents)
             (dim-check-all '((0 #f)) depth contents) ;retry with default dims
             '())]
          [else
           (and-let* ([sh (dim-check (car dims) contents)]
                      [shs (fold (^[content shs]
                                   (and shs
                                        (dim-check-all shs (+ depth 1) content)))
                                 (cdr dims) contents)])
             (cons sh shs))]))
  (define (dim->shape dim-list)
    ;; ((start len) (start2 len2) ...)
    ;;   => (start end start2 end2 ...)
    (apply shape (append-map (^p `(,(car p) ,(+ (car p) (cadr p)))) dim-list)))
  ;; Flattens CONTENTS through DEPTH+1 levels of nesting, consing the
  ;; elements onto REST.  With a negative DEPTH, CONTENTS itself is treated
  ;; as a single element.
  (define (flatten contents depth rest)
    (if (< depth 0)
      (cons contents rest)
      (fold-right (cute flatten <> (- depth 1) <>) rest contents)))

  ;; Start parsing
  (define given-dims (read-dimensions '()))
  (define contents (read port))

  ;; If both dimensions and rank are given, they must match.
  (when (and (>= rank 0)
             (not (null? given-dims))
             (not (= rank (length given-dims))))
    (err "Array literal's rank and dimensions don't match" contents))
  (when (and (< rank 0)
             (not (null? given-dims)))
    ;; If rank isn't given but dimensions are, set the rank.
    (set! rank (length given-dims)))

  (let ([dim-list (dim-check-all given-dims 0 contents)])
    (unless dim-list
      ;; The given dimensions and the contents don't match.
      (err "Array literal has inconsistent shape" contents))
    (when (and (>= rank 0)
               (< (length dim-list) rank))
      ;; The deduced dimensions don't match the given rank.
      (err "Array literal has inconsistent rank" contents))

    ($ make-array-internal type-class (dim->shape dim-list)
       :init-list (flatten contents (- (length dim-list) 1) '()))))

;;-------------------------------------------------------------
;; Affine mapper
;;
;;  Given begin-vector Vb = #s32(b0 b1 ... bN)
;;        end-vector   Ve = #s32(e0 e1 ... eN)
;;        where b0 <= e0, ..., bN <= eN
;;  Returns a procedure,
;;    which calculates 1-dimensional offset off to the backing storage,
;;          from given index vector #s32(i0 i1 ... iN)
;;
;;  Pre-calculation:
;;    sizes          s0 = e0 - b0, s1 = e1 - b1 ...
;;    coefficients   c0   = s1 * s2 * ... * sN
;;                   c1   = s2 * ... * sN
;;                   cN-1 = sN
;;                   cN   = 1
;;
;;  Mapping
;;   off = c0*(i0-b0) + c1*(i1-b1) + .. + cN*(iN-bN)
;;
;;

;; True iff every element of the s32vector VEC is zero.
(define (zero-vector? vec)
  ;; range-check yields #f when no element falls outside [0, 0]
  (let1 out-of-range (s32vector-range-check vec 0 0)
    (not out-of-range)))

;; Computes the coefficient vector for the start vector VB and the end
;; vector VE (both s32vectors): the k-th coefficient is the product of the
;; sizes of all the dimensions after k, and the last one is 1.
(define (coefficient-vector Vb Ve)
  (define sizes (s32vector->list (s32vector-sub Ve Vb)))
  ;; Accumulate the running products from the last dimension backwards.
  ;; The head of the result covers all the dimensions and isn't needed.
  (define products
    (fold-right (^[size acc] (cons (* size (car acc)) acc))
                '(1)
                sizes))
  (coerce-to <s32vector> (cdr products)))

;; shape index tests

;; Base case: IND is an s32vector.  It is valid when it has one element per
;; dimension of the shape SH and each element lies in [start, end).
(define-method shape-valid-index? ((sh <array-base>) (ind <s32vector>))
  (let-values ([(lower upper) (shape->start/end-vector sh)])
    (cond
     [(not (= (s32vector-length ind) (s32vector-length lower))) #f]
     ;; range-check is inclusive, hence end - 1
     [(s32vector-range-check ind lower (s32vector-sub upper 1)) #f]
     [else #t])))

;; The following methods convert the index to an s32vector and delegate
;; to the <s32vector> method.

;; index given as a vector
(define-method shape-valid-index? ((sh <array-base>) (ind <vector>))
  (let1 index (vector->s32vector ind)
    (shape-valid-index? sh index)))
;; index given as a list
(define-method shape-valid-index? ((sh <array-base>) (ind <pair>))
  (let1 index (list->s32vector ind)
    (shape-valid-index? sh index)))
;; index given as an array
(define-method shape-valid-index? ((sh <array-base>) (ind <array>))
  (let1 index (vector->s32vector (array->vector ind))
    (shape-valid-index? sh index)))
;; index given as separate integer arguments
(define-method shape-valid-index? ((sh <array-base>) (ind <integer>) . more-index)
  (let1 index (list->s32vector (cons* ind more-index))
    (shape-valid-index? sh index)))
;; No index at all: valid only if the shape itself has no elements,
;; i.e. it describes a zero-dimensional array.
(define-method shape-valid-index? ((sh <array-base>))
  (let1 bounds (array->list sh)
    (null? bounds)))

;; array index tests

;; Same as shape-valid-index?, but takes the array AR rather than its shape.
(define (array-valid-index? ar . args)
  (let1 sh (array-shape ar)
    (apply shape-valid-index? sh args)))

;;---------------------------------------------------------------
;; Shape
;;   ... is a special array that has a few constraints;
;;   that is, suppose S is a shape, then:
;;   (1) the shape of S is [0, d, 0, 2], where d is the rank
;;       of the array S represents.
;;   (2) all the elements of S are exact integers.
;;   (3) (array-ref S n 0) <= (array-ref S n 1) where 0 <= n < d
;;

;; Creates a shape from ARGS, which are b0 e0 b1 e1 ... where bK and eK
;; are the beginning and the ending index of the K-th dimension.
;; The result is an <array> whose own shape is [0, rank, 0, 2] and whose
;; backing storage holds ARGS in the given order.
;; Signals an error if the number of ARGS is odd, if any of them isn't an
;; exact integer, or if a beginning index is larger than its ending index.
(define (shape . args)
  (let1 arglen (length args)
    (unless (even? arglen)
      (error "shape arguments not even" args))
    ;; Check validity of the shape.  exact-integer? is used rather than
    ;; exact? + integer? so that a non-numeric argument gets the message
    ;; below instead of an error raised by exact? itself.
    (do ([l args (cddr l)])
        [(null? l)]
      (unless (exact-integer? (car l))
        (error "exact integer required for shape, but got" (car l)))
      (unless (exact-integer? (cadr l))
        (error "exact integer required for shape, but got" (cadr l)))
      (unless (<= (car l) (cadr l))
        (errorf "beginning index ~s is larger than ending index ~s in shape argument: ~s" (car l) (cadr l) args)))
    ;; make array.
    (let* ([rank (quotient arglen 2)]
           [Vb (s32vector 0 0)]
           [Ve (s32vector rank 2)]
           [Vc (coefficient-vector Vb Ve)])
      (make <array>
        :start-vector Vb
        :end-vector Ve
        :coefficient-vector Vc
        :backing-storage (list->vector args)))))

;; Returns two s32vectors, the start indices and the end indices of all
;; the dimensions described by SHAPE.
(define (shape->start/end-vector shape)
  (define axes (iota (array-end shape 0)))
  ;; column 0 of the shape holds the starts, column 1 the ends
  (define (column j)
    (map-to <s32vector> (^[axis] (array-ref shape axis j)) axes))
  (values (column 0) (column 1)))

;; Creates a shape from the start vector VB and the end vector VE by
;; interleaving their elements: b0 e0 b1 e1 ...
(define (start/end-vector->shape Vb Ve)
  (let loop ([starts (s32vector->list Vb)]
             [ends   (s32vector->list Ve)]
             [acc    '()])                 ;interleaved so far, reversed
    (cond
     ;; when one side runs out, whatever remains of the other follows
     [(null? starts) (apply shape (append-reverse acc ends))]
     [(null? ends)   (apply shape (append-reverse acc starts))]
     [else (loop (cdr starts)
                 (cdr ends)
                 (cons* (car ends) (car starts) acc))])))

;;---------------------------------------------------------------
;; Make general array
;;

;; Creates an array of CLASS with SHAPE.
;;   :init-list - if it is a non-empty list, its elements become the content
;;                of the array; its length must be equal to the array size.
;;   :init-1    - otherwise, if given, every element is set to it.
;; If neither applies, the backing storage is created without a fill value.
(define (make-array-internal class shape
                             :key init-1 init-list)
  (receive (Vb Ve) (shape->start/end-vector shape)
    (let* ([Vc (coefficient-vector Vb Ve)]
           [storage-class (~ class'backing-storage-class)]
           [nelts (fold * 1 (s32vector-sub Ve Vb))])
      ;; Allocates the backing storage, passing FILL on if there is one.
      (define (allocate . fill)
        (if (eq? storage-class <vector>)
          (apply make-vector nelts fill)
          (apply make-uvector storage-class nelts fill)))
      (define storage
        (cond [(pair? init-list)
               (if (= (%array-size Vb Ve) (length init-list))
                 (coerce-to storage-class init-list)
                 (error "Array initialization list doesn't match array size"
                        init-list))]
              [(undefined? init-1) (allocate)]
              [else (allocate init-1)]))
      (make class
        :start-vector Vb
        :end-vector Ve
        :coefficient-vector Vc
        :backing-storage storage))))

;; For each prefix PFX, defines two procedures:
;;   (make-PFXarray shape :optional init-1)
;;   (PFXarray shape . inits)
;; Both create an instance of <PFXarray> through make-array-internal.
(define-macro (define-array-ctors . pfxs)
  (define (ctor-forms pfx)
    (let ([maker-name (symbol-append 'make- pfx 'array)]
          [ctor-name  (symbol-append pfx 'array)]
          [klass      (symbol-append '< pfx 'array>)])
      (list
       `(define (,maker-name shape :optional init-1)
          (make-array-internal ,klass shape :init-1 init-1))
       `(define (,ctor-name shape . inits)
          (make-array-internal ,klass shape :init-list inits)))))
  `(begin ,@(append-map ctor-forms pfxs)))

;; The empty symbol || yields make-array and array for the <array> class.
(define-array-ctors || u8 s8 u16 s16 u32 s32 u64 s64 f16 f32 f64 c32 c64 c128)

;; Returns a new array of the same class as AR that holds the elements of
;; AR within the shape SH.  The indices of the result start from zero in
;; every dimension.
(define (subarray ar sh)
  (let-values ([(origin limit) (shape->start/end-vector sh)])
    (let* ([rank         (s32vector-length origin)]
           [zeros        (make-s32vector rank 0)]
           [extent       (s32vector-sub limit origin)]
           [target-shape (start/end-vector->shape zeros extent)])
      ;; IND is an index into the result; shift it by the origin of SH
      ;; to get the corresponding index into AR.
      (define (fetch ind)
        (array-ref ar (s32vector->vector (s32vector-add origin ind))))
      (rlet1 result (make-array-internal (class-of ar) target-shape)
        (array-retabulate! result target-shape fetch (make-vector rank))))))

;; True iff OBJ is an instance of <array-base> or of one of its subclasses.
(define (array? obj)
  (if (is-a? obj <array-base>) #t #f))

;; array-rank is in src/libarray.scm

;; Returns the start index of the K-th dimension of ARRAY.
(define (array-start array k)
  (if (array? array)
    (s32vector-ref (start-vector-of array) k)
    (error "array required, but got" array)))

;; Returns the end index of the K-th dimension of ARRAY.
(define (array-end array k)
  (if (array? array)
    (s32vector-ref (end-vector-of array) k)
    (error "array required, but got" array)))

;;---------------------------------------------------------------
;; Array ref and set!
;;

;; array-ref and array-set! are now defined in src/libarray.scm

;;---------------------------------------------------------------
;; Share array
;;

;; PROC maps the resulting array's index Ia = [a0 a1 a2 ... an-1] to
;; the original array's index Ib = [b0 b1 ... bm-1].  The mapping must be
;; an affine transformation, i.e.
;;   Ib = X・Ia + Y
;; where X is nxm matrix and Y is a vector of length m.
;;
;; Note that the original indices Ib are mapped to the 1-dimensional
;; backing storage position with
;;   Vc・(Ib - Vb) + off
;; where Vc is the original coefficient vector, Vb is the original
;; start vector, and off is the original offset.
;; Thus, given Ia, we can say
;;
;;   Vc・(X・Ia + Y - Vb) + off
;;     = (Vc・X)・Ia + Vc・(Y - Vb) + off
;;
;; As the accessor of the new array computes the position from
;; the given index by
;;
;;   Vc'・(Ia - Vb') + off'
;;
;; where Vb' is the start vector of new array (given by the shape),
;; Vc' is the new coefficient vector, and off' is the new offset.
;; Hence we can compute Vc' and off' as follows.
;;
;;   Vc'  = Vc・X
;;   off' = (Vc・X)・Vb' + Vc・(Y - Vb) + off
;;

;; Compute X and Y
;; Probes PROC, which takes RANK indices and returns indices as multiple
;; values, to recover the affine mapping it performs.  Returns two values:
;;   X - a list of s32vectors, one per argument of PROC; the i-th one tells
;;       how much each result changes when the i-th argument goes from 0
;;       to 1.  It is () if PROC returns no values.
;;   Y - an s32vector of what PROC returns for the all-zero index.
(define (affine-proc->coeffs proc rank)
  (define axes (iota rank))
  (define (unit-index i)                ;1 at position i, 0 elsewhere
    (map (^j (if (= j i) 1 0)) axes))
  (receive offsets (apply proc (make-list rank 0))
    ;; one row per result of PROC, filled in column by column below
    (let1 rows (map (^_ (make-list rank 0)) offsets)
      (for-each
       (^i (receive images (apply proc (unit-index i))
             (for-each (^[row image offset]
                         (set! (~ row i) (- image offset)))
                       rows images offsets)))
       axes)
      (values (if (null? offsets)
                '()
                (apply map s32vector rows)) ; make list of column vector
              (coerce-to <s32vector> offsets)))))

;; compute Vc・X.  Note that X is a list of column vectors.
;; Returns an s32vector of the dot products of VC with each of the column
;; vectors in the list X.
(define (mul-vec-mat Vc X)
  (map-to <s32vector>
          (^[column] (s32vector-dot Vc column))
          X))

;; Returns the coefficient vector and the offset of an array that shares
;; the backing storage of OARRAY.  Vb. is the start vector of the new
;; array; X and Y are what affine-proc->coeffs returned.
(define (compute-new-coeff+offset oarray Vb. X Y)
  (let* ([old-coeff  (~ oarray'coefficient-vector)]
         [old-start  (~ oarray'start-vector)]
         [old-offset (~ oarray'offset)]
         [new-coeff  (mul-vec-mat old-coeff X)])
    (values new-coeff
            (cond
             [(null? X) 0]              ;empty array case
             [else (+ (s32vector-dot new-coeff Vb.)
                      (s32vector-dot old-coeff (s32vector-sub Y old-start))
                      old-offset)]))))

;; Returns an array of the same class as ARRAY, with the given SHAPE, that
;; uses the very same backing storage as ARRAY.  PROC maps an index of the
;; new array to an index of ARRAY.
(define (share-array array shape proc)
  (let*-values ([(new-start new-end) (shape->start/end-vector shape)]
                [(X Y) (affine-proc->coeffs proc (size-of new-start))]
                [(new-coeff new-offset)
                 (compute-new-coeff+offset array new-start X Y)])
    (make (class-of array)
      :start-vector new-start
      :end-vector   new-end
      :coefficient-vector new-coeff
      :offset new-offset
      :backing-storage (~ array'backing-storage))))

;;---------------------------------------------------------------
;; Array utilities
;;

;; Returns the shape of AR as a new rank x 2 array that holds the start
;; and the end index of every dimension.
(define (array-shape ar)
  (let* ([rank   (array-rank ar)]
         [bounds (append-map (^[axis] (list (array-start ar axis)
                                            (array-end ar axis)))
                             (iota rank))])
    (apply array (shape 0 rank 0 2) bounds)))

;; Number of indices along the dimension DIM of AR.
(define (array-length ar dim)
  (let ([hi (array-end ar dim)]
        [lo (array-start ar dim)])
    (- hi lo)))

;; Product of (end - start) over all the dimensions; 1 if there are none.
(define (%array-size Vb Ve)
  (apply * (map - Ve Vb)))

;; Total number of elements addressable in AR.
(define (array-size ar)
  (let ([lo (~ ar'start-vector)]
        [hi (~ ar'end-vector)])
    (%array-size lo hi)))

;; Returns #t iff the arrays A and B have the same rank and the same start
;; and end index in every dimension, and EQ (equal? by default) holds for
;; every pair of corresponding elements.  Otherwise returns #f.
(define (array-equal? a b :optional (eq equal?))
  (let1 r (array-rank a)
    (define (same-bounds? dim)
      (and (= (array-start a dim) (array-start b dim))
           (= (array-end a dim) (array-end b dim))))
    (define (same-elements?)
      (let/cc break
        (array-for-each-index a
          (^[index] (unless (eq (array-ref a index)
                                (array-ref b index))
                      (break #f)))
          (make-vector r))
        #t))
    (cond
     [(not (= r (array-rank b))) #f]
     ;; array-for-each-index never calls its procedure for a rank-0 array,
     ;; so the only element has to be compared explicitly.
     [(zero? r) (if (eq (array-ref a) (array-ref b)) #t #f)]
     [(not (every same-bounds? (iota r))) #f]
     [else (same-elements?)])))

;; Makes equal? work on arrays; array-equal? compares elements with equal?
;; unless told otherwise.
(define-method object-equal? ((a <array-base>) (b <array-base>))
  (array-equal? a b))

;; Returns a procedure that applies PROC to the indices given by a vector.

;; Returns a procedure taking PROC and a vector VEC of RANK indices, which
;; calls PROC with the elements of VEC as separate arguments.  Ranks up to
;; 3 are handled without building an argument list.
(define (array-index-applier rank)
  (cond
   [(eqv? rank 0) (lambda (proc vec) (proc))]
   [(eqv? rank 1) (lambda (proc vec) (proc (vector-ref vec 0)))]
   [(eqv? rank 2) (lambda (proc vec)
                    (proc (vector-ref vec 0) (vector-ref vec 1)))]
   [(eqv? rank 3) (lambda (proc vec)
                    (proc (vector-ref vec 0)
                          (vector-ref vec 1)
                          (vector-ref vec 2)))]
   [else (lambda (proc vec) (apply proc (vector->list vec)))]))

;; Calls PROC with each element of AR.
;; The elements are visited through the indices of AR rather than by
;; walking the backing storage directly: an array created by share-array
;; uses the backing storage of another array, which can hold elements that
;; don't belong to AR.
(define (array-for-each proc ar)
  (let1 rank (array-rank ar)
    (if (zero? rank)
      ;; array-for-each-index doesn't iterate over a rank-0 array
      (proc (array-ref ar))
      (array-for-each-index ar
        (^[ind] (proc (array-ref ar ind)))
        (make-vector rank)))))

;; Returns #t if PRED returns a true value for some element of AR, and #f
;; otherwise.  Stops at the first element that satisfies PRED.
;; The elements are visited through the indices of AR rather than by
;; walking the backing storage directly, since the backing storage of a
;; shared array can hold elements that don't belong to AR.
(define (array-any pred ar)
  (let1 rank (array-rank ar)
    (if (zero? rank)
      ;; array-for-each-index doesn't iterate over a rank-0 array
      (if (pred (array-ref ar)) #t #f)
      (let/cc found
        (array-for-each-index ar
          (^[ind] (when (pred (array-ref ar ind)) (found #t)))
          (make-vector rank))
        #f))))

;; Returns #t if PRED returns a true value for every element of AR, and #f
;; otherwise.  Stops at the first element that fails PRED.
;; The elements are visited through the indices of AR rather than by
;; walking the backing storage directly, since the backing storage of a
;; shared array can hold elements that don't belong to AR.
(define (array-every pred ar)
  (let1 rank (array-rank ar)
    (if (zero? rank)
      ;; array-for-each-index doesn't iterate over a rank-0 array
      (if (pred (array-ref ar)) #t #f)
      (let/cc found
        (array-for-each-index ar
          (^[ind] (unless (pred (array-ref ar ind)) (found #f)))
          (make-vector rank))
        #t))))

;; repeat construct

;; The engine behind array-for-each-index and friends.
;;   PROC  - called for each index.
;;   KEEP  - list of the dimensions to loop over.  The loops below match it
;;           against dimension numbers counted up from 0, so it should be
;;           in increasing order.  Nothing is done if it is empty.
;;   VB,VE - s32vectors of start and end indices.
;;   IND   - the list of optional arguments given by the caller.
;;           If empty, PROC receives the indices as separate arguments.
;;           Otherwise its first element is an index object (a vector, an
;;           array, or an s8/s16/s32vector) which is updated in place and
;;           passed to PROC as a single argument.
;; Signals an error if an index object of any other type is given.
(define (array-for-each-int proc keep Vb Ve ind)

  (define i (if (pair? ind) (car ind) (make-vector (s32vector-length Vb))))
  (define applier #f)

  ;; Used when no index object is given: the indices are kept in the
  ;; vector I and spread into arguments of PROC by APPLIER.
  (define (list-loop dim k)
    (if (= dim (car k))
      (let ([e (s32vector-ref Ve dim)]
            [rest (cdr k)])
        (if (null? rest)
          (do ([k (s32vector-ref Vb dim) (+ k 1)])
              [(= k e)]
            (vector-set! i dim k)
            ;; use an applier
            (applier proc i))
          (do ([k (s32vector-ref Vb dim) (+ k 1)])
              [(= k e)]
            (vector-set! i dim k)
            (list-loop (+ dim 1) rest))))
      (list-loop (+ dim 1) k)))

  ;; Used when an index object is given: SETTER stores an index into it.
  (define (helper-loop setter dimensions keep-ls)
    (let loop ([dim dimensions]
               [k keep-ls])
      (if (= dim (car k))
        ;; we loop over this dimension
        (let ([e (s32vector-ref Ve dim)]
              [rest (cdr k)])
          (if (null? rest)
            ;; inline last loop to avoid excess procedure calls
            (do ([k (s32vector-ref Vb dim) (+ k 1)])
                [(= k e)]
              (setter i dim k)
              (proc i))
            ;; set the index for this dimension and loop
            (do ([k (s32vector-ref Vb dim) (+ k 1)])
                [(= k e)]
              (setter i dim k)
              (loop (+ dim 1) rest))))
        ;; skip this dimension
        (loop (+ dim 1) k))))

  (unless (null? keep)
    (cond
      [(null? ind)
       (set! applier (array-index-applier (s32vector-length Vb)))
       (list-loop 0 keep)]
      [(vector? i)      (helper-loop vector-set! 0 keep)]
      [(array? i)       (helper-loop array-set! 0 keep)]
      [(s8vector? i)    (helper-loop s8vector-set! 0 keep)]
      [(s16vector? i)   (helper-loop s16vector-set! 0 keep)]
      [(s32vector? i)   (helper-loop s32vector-set! 0 keep)]
      ;; This clause used to lack the call to error, so an unsupported
      ;; index object was silently returned instead of being reported.
      [else (error "bad index object (vector or array required)" (car ind))])))

;; Calls PROC for each index of AR, looping over all the dimensions.
;; An index object may be given as an optional argument; see
;; array-for-each-int for how it changes the way PROC is called.
(define (array-for-each-index ar proc . o)
  (let ([axes (iota (array-rank ar))]
        [lo   (start-vector-of ar)]
        [hi   (end-vector-of ar)])
    (array-for-each-int proc axes lo hi o)))

;; Like array-for-each-index, but loops only over the dimensions listed
;; in KEEP.
(define (array-for-each-index-by-dimension ar keep proc . o)
  (let ([lo (start-vector-of ar)]
        [hi (end-vector-of ar)])
    (array-for-each-int proc keep lo hi o)))

;; Calls PROC for each index within the shape SH.  The optional index
;; object is handled as in array-for-each-index.
(define (shape-for-each sh proc . o)
  (define axes (iota (array-end sh 0)))
  ;; column 0 of a shape holds the starts, column 1 the ends
  (define (column j)
    (map-to <s32vector> (^[axis] (array-ref sh axis j)) axes))
  (array-for-each-int proc axes (column 0) (column 1) o))

;; Creates a general array of the shape SH, fills it by applying
;; array-retabulate! with ARGS, and returns it.
(define (tabulate-array sh . args)
  (let1 res (make-array sh)
    (apply array-retabulate! res args)
    res))

;; Mapping onto array.
;;   array-retabulate!
;;   array-map!
;;   array-map
;; These may take optional shape argument.  It is redundant, for the shape
;; has to match the target array's shape anyway.  In Jussi's reference
;; implementation, giving the shape allows some optimization.  In my
;; implementation it's not much of use.  I keep it just for compatibility
;; to Jussi's.

;; Variant that takes the shape explicitly.  SH is ignored at the moment;
;; we simply delegate to the method without the shape.
(define-method array-retabulate! ((ar <array-base>) (sh <array-base>) (proc <procedure>) . o)
  ;; need to check the shape sh matches the ar's shape.
  (apply array-retabulate! (cons* ar proc o)))

;; Sets each element of AR to what PROC returns for its index.
;; Without an optional argument, PROC is called with the indices as
;; separate arguments.  If a vector or an array is given as the optional
;; argument, it is used as the index object and PROC is called with it.
;; Signals an error if the optional argument is of any other type.
(define-method array-retabulate! ((ar <array-base>) (proc <procedure>) . o)
  (cond
   [(null? o)
    (let1 applier (array-index-applier (array-rank ar))
      (array-for-each-index ar
                            (^[ind] (array-set! ar ind (applier proc ind)))
                            (make-vector (array-rank ar))))]
   [(or (vector? (car o)) (array? (car o)))
    (array-for-each-index ar
                          (^[ind] (array-set! ar ind (proc ind)))
                          (car o))]
   ;; This clause used to lack the call to error, so a bad index object
   ;; was silently returned and the array was left untouched.
   [else (error "bad index object (vector or array required)" (car o))]))

;; Variant that takes the shape explicitly.  SH is ignored at the moment;
;; we simply delegate to the methods without the shape.
(define-method array-map! ((ar <array-base>) (sh <array-base>) (proc <procedure>) ar0 . more-arrays)
  ;; need to check the shape sh matches the ar's shape.
  (apply array-map! (cons* ar proc ar0 more-arrays)))

;; One source array: for each index of AR0, stores the result of PROC on
;; the element of AR0 into AR at the same index.
(define-method array-map! ((ar <array-base>) (proc <procedure>) ar0)
  (let1 index (make-vector (array-rank ar))
    (array-for-each-index ar0
      (lambda (ind)
        (array-set! ar ind (proc (array-ref ar0 ind))))
      index)))

;; Two source arrays: for each index of AR0, stores the result of PROC on
;; the elements of AR0 and AR1 into AR at the same index.
(define-method array-map! ((ar <array-base>) (proc <procedure>) ar0 ar1)
  (let1 index (make-vector (array-rank ar))
    (array-for-each-index ar0
      (lambda (ind)
        (array-set! ar ind (proc (array-ref ar0 ind)
                                 (array-ref ar1 ind))))
      index)))

;; Three or more source arrays: for each index of AR, applies PROC to the
;; elements of all the source arrays at that index and stores the result
;; into AR.
(define-method array-map! ((ar <array-base>) (proc <procedure>) ar0 ar1 ar2 . more)
  (let ([sources (cons* ar0 ar1 ar2 more)]
        [index   (make-vector (array-rank ar))])
    (array-for-each-index ar
      (lambda (ind)
        (let1 args (map (^[src] (array-ref src ind)) sources)
          (array-set! ar ind (apply proc args))))
      index)))

;; Variant that takes the shape explicitly.  SH is ignored; we simply
;; delegate to the method without the shape.
(define-method array-map ((sh <array-base>) (proc <procedure>) ar0 . more)
  (apply array-map (cons* proc ar0 more)))

;; Creates a general array of the same shape as AR0, fills it with
;; array-map!, and returns it.
(define-method array-map ((proc <procedure>) ar0 . more)
  (let1 target (make-array (array-shape ar0))
    (apply array-map! target proc ar0 more)
    target))

;; Returns a fresh vector of the elements of AR, in the order
;; array-for-each-index visits them.
(define (array->vector ar)
  (let1 rank (array-rank ar)
    (cond
     ;; a rank-0 array has exactly one element and no index to iterate
     [(zero? rank) (vector (array-ref ar))]
     [else
      (with-builder (<vector> add! get :size (array-size ar))
        (array-for-each-index ar
          (lambda (ind) (add! (array-ref ar ind)))
          (make-vector rank))
        (get))])))

;; Returns a fresh list of the elements of AR, in the order
;; array-for-each-index visits them.
(define (array->list ar)
  (let1 rank (array-rank ar)
    (cond
     ;; a rank-0 array has exactly one element and no index to iterate
     [(zero? rank) (list (array-ref ar))]
     [else
      (with-builder (<list> add! get)
        (array-for-each-index ar
          (lambda (ind) (add! (array-ref ar ind)))
          (make-vector rank))
        (get))])))

;; Returns the elements of AR as a nested list, one nesting level per
;; dimension.  For a rank-0 array the sole element itself is returned.
(define (array->nested-list ar)
  (define rank (array-rank ar))
  (define ind (make-vector rank 0))
  (define last-axis (- rank 1))
  (define (rec axis)
    (let* ([dim   (array-length ar axis)]
           [start (array-start ar axis)])
      ;; Fixes the index of this axis to I, then either fetches the
      ;; element or descends into the next axis.
      (define (item i)
        (vector-set! ind axis i)
        (if (= axis last-axis)
          (array-ref ar ind)
          (rec (+ axis 1))))
      ;; go backwards so that consing yields the items in index order
      (let loop ([i (- (+ start dim) 1)]
                 [acc '()])
        (if (< i start)
          acc
          (loop (- i 1) (cons (item i) acc))))))
  (cond [(zero? rank) (array-ref ar)]
        [else (rec 0)]))

;; Returns the elements of AR as a nested vector, one nesting level per
;; dimension.  For a rank-0 array the sole element itself is returned,
;; the same way array->nested-list does.
(define (array->nested-vector ar)
  (define rank (array-rank ar))
  (define ind (make-vector rank 0))
  (define (rec axis)
    (define dim (array-length ar axis))
    (define start (array-start ar axis))
    (define storage (make-vector dim))
    ;; Fixes the index of this axis to I and stores into slot J of STORAGE
    ;; either the element or the vector for the next axis.
    (define fill!
      (if (= axis (- rank 1))
        (^[i j] (vector-set! ind axis i)
          (vector-set! storage j (array-ref ar ind)))
        (^[i j] (vector-set! ind axis i)
          (vector-set! storage j (rec (+ axis 1))))))
    (do ([i (- (array-end ar axis) 1) (- i 1)]
         [j (- dim 1) (- j 1)])
        [(< i start) storage]
      (fill! i j)))
  ;; A rank-0 array has no axis 0, so rec can't be used for it.
  (if (zero? rank)
    (array-ref ar)
    (rec 0)))
