1 /*
2 Copyright (c) 2022 Zenw
3 Released under the MIT license
4 https://opensource.org/licenses/mit-license.php
5 */
6 module ggeD.einsum;
7 
8 import std;
9 import ggeD.ggeD;
10 
11 
unittest
{
    // M = [[0,1,2],[3,4,5],[6,7,8]]; v is its first row.
    auto M = iota(9.).gged!double(3,3);
    auto v = M[0,0..$];
    assert(M == [[0, 1, 2],[3, 4, 5],[6, 7, 8]]);
    assert(v == [0, 1, 2]);

    // i is repeated and therefore summed, j stays free: mv[j] = sum_i M[i,j] * v[i].
    auto mv = Einsum | M.ij * v.i;
    assert(mv == [15,18,21]);

    // A repeated index on a single operand yields the trace.
    auto trace = Einsum | M.ii;
    assert(trace == 12);

    // An explicit result index order ("ji") transposes.
    auto mT = Einsum.ji | M.ij;
    assert(mT == [[0, 3, 6], [1, 4, 7], [2, 5, 8]]);

    // Contracting with a function-backed Kronecker delta also gives the trace.
    auto kronecker = fnTensor((ulong i,ulong j)=>(i==j?1.:0.));
    auto traceViaDelta = Einsum | M.ij*kronecker.ij;
    assert(traceViaDelta == 12);

    // br! applies a function elementwise; tan(atan(x)) round-trips M.
    auto roundTrip = Einsum | br!tan(br!atan(M.ij)*1.);
    assert(M == roundTrip);

    // br! with two arguments sharing index i: the result is summed over i.
    auto atanSum = Einsum | br!atan2(M[0,0..3].i,1.+M[0..3,0].i);
    assert(atanSum == atan2(M[0,0],1+M[0,0]) + atan2(M[0,1],1+M[1,0]) + atan2(M[0,2],1+M[2,0]) );
}
38 
39 
/// Wraps a call of `fun` on `args` as a `Func` node of the einsum DSL,
/// e.g. `br!tan(A.ij)` or `br!atan2(x.i, y.i)`.
/// Nothing is evaluated here: the arguments are only stored, and the call
/// itself is emitted into the generated `calc` of the expression.
@nogc auto br(alias fun,Args...)(Args args)
{
    alias Node = Func!("",fun,Args);
    return Node(args);
}
44 
45 
46 
/// Wraps a callable as an on-demand tensor for the einsum DSL.
///
/// `fnTensor(f).ij` yields an `FnTensor!("ij", F)` whose element at the
/// current (i, j) is `f(i, j)`, e.g. a Kronecker delta:
/// `fnTensor((ulong i, ulong j) => (i == j ? 1. : 0.)).ij`.
@nogc auto fnTensor(F)(F fun) if(isCallable!F)
{
    // `static` so the helper stores the callable as an ordinary field instead
    // of holding a hidden pointer to this function's stack frame. A returned
    // nested struct whose methods read `fun` would otherwise require the frame
    // to be heap-allocated (a GC closure), which conflicts with @nogc.
    static struct Sub
    {
        F fn_;
        auto opDispatch(string idx)()
        {
            return FnTensor!(idx,F)(fn_);
        }
    }
    return Sub(fun);
}
58 
59 
/// Entry point of the einsum DSL.
///
/// `Einsum | expr` evaluates `expr` with the free (kept) and dummy (summed)
/// indices inferred by `getIndex`. `Einsum.ji | expr` additionally fixes the
/// result index string to "ji". When no free index remains, the result is the
/// scalar returned by the node's `calc`. Otherwise it is a `Slice` over an
/// `EinsumIterator`, wrapped by `gged`.
struct Einsum
{
    static:
    /// Evaluates `leaf` with `ResultIdx` as the requested result index string
    /// ("" means: infer it).
    auto evalnogc(string ResultIdx="",Node)(Node leaf) @nogc nothrow
    {
        import mir.ndslice;
        import ggeD.iterator;

        // Re-instantiate the root node's template with ResultIdx as its first
        // template argument (all other arguments unchanged), so that its
        // generated calc() is built for the requested result indices.
        alias RootTemplate = TemplateOf!Node;
        alias RootArgs = TemplateArgsOf!Node;
        alias Rooted = RootTemplate!(ResultIdx,RootArgs[1..$]);
        auto newone = Rooted(leaf.tupleof);

        alias freeIdx = Alias!(getIndex!(Node,ResultIdx)[0]);
        static if(freeIdx.length == 0)
        {
            // Every index is summed inside calc(): the result is a scalar.
            return newone.calc();
        }
        else
        {
            // Comma-separated shape expressions. They refer to the local
            // variable `newone` by name, so that variable must keep its name.
            alias dims = Alias!(getShapes!(Rooted,"newone",freeIdx));
            auto grid = mixin("iota("~dims~")");
            size_t[freeIdx.length] extent = mixin("["~dims~"]");
            auto it = EinsumIterator!(typeof(grid._iterator),typeof(grid),Rooted)(grid._iterator,grid,newone);
            return Slice!(typeof(it),freeIdx.length)(extent,it).gged;
        }
    }

    /// `Einsum | expr`: evaluate with inferred result indices.
    auto opBinary(string op,R)(R rhs) @nogc nothrow  if(op == "|") 
    {
        return evalnogc!""(rhs);
    }
    
    /// `Einsum.<indices>`: returns a helper whose `| expr` evaluates `expr`
    /// with `name` as the result index string.
    auto opDispatch(string name)()  @nogc nothrow 
    { 
        struct Sub
        {
            auto opBinary(string op,R)(R rhs) if(op == "|")
            {
                return evalnogc!name(rhs);
            }
        }
        return Sub();
    }
}
105 
106 package(ggeD):
/// Returns the characters of `input` that occur exactly once, in their
/// original order. When `ignr` is non-empty it is returned unchanged instead
/// (an explicitly requested result index string overrides the inference).
string onlyUniq(string input,string ignr="")
{
    if (ignr.length != 0)
        return ignr;

    dchar[] singles;
    foreach (dchar ch; input)
    {
        if (input.count(ch) == 1)
            singles ~= ch;
    }
    return singles.to!string;
}
/// Returns the summed-over (dummy) characters of `input`, sorted and without
/// duplicates. With an empty `ignr`, these are the characters occurring more
/// than once. Otherwise they are the characters of `input` that do not occur
/// in `ignr`.
string onlyDummy(string input,string ignr="")
{
    immutable bool useResult = ignr != "";
    auto picked = input.to!(dchar[])
        .filter!(ch => useResult ? ignr.count(ch) == 0 : input.count(ch) > 1)
        .array;
    picked.sort();
    return picked.uniq.array.to!string;
}
115 
116 
/// Flattens a list of expression node types into the operand types that
/// contribute an element type: `Leaf` instantiations and basic (scalar)
/// types. `Func` and `Tree` nodes are expanded recursively through their leaf
/// lists; anything else (e.g. `FnTensor`) contributes nothing.
template filterTensors(Leafs...)
{
    static if(Leafs.length == 0)
    {
        alias filterTensors = AliasSeq!();
    }
    else static if(Leafs.length > 1)
    {
        // Process the head, then the tail, so the original order is kept.
        alias filterTensors = AliasSeq!(filterTensors!(Leafs[0]),filterTensors!(Leafs[1..$]));
    }
    else static if(is(Leafs[0] == Leaf!(Rx,Ix,Tx),string Rx,string Ix,Tx) || isBasicType!(Leafs[0]))
    {
        alias filterTensors = Leafs[0];
    }
    else static if(is(Leafs[0] == Func!(Rx,Fx,Inner),string Rx,alias Fx,Inner...))
    {
        alias filterTensors = filterTensors!(Inner);
    }
    else static if(is(Leafs[0] == Tree!(Rx,Lx,Ry,Ox,Inner),string Rx,Lx,Ry,string Ox,Inner...))
    {
        alias filterTensors = filterTensors!(Inner);
    }
    else
    {
        alias filterTensors = AliasSeq!();
    }
}
/// Returns `A` with every character that also occurs in `B` removed,
/// keeping the remaining characters in their original order.
string removeCharacters(string A, string B)
{
    dchar[] kept;
    foreach (dchar ch; A)
    {
        if (!B.canFind(ch))
            kept ~= ch;
    }
    return kept.to!string;
}
152 
/// Compile-time index analysis of an einsum expression node.
///
/// Evaluates to `AliasSeq!(free, dummy)`:
///   free  - index letters that stay in the result; the node's generated
///           `calc` takes one `size_t` parameter per letter, in this order.
///   dummy - index letters summed by nested `foreach` loops inside that `calc`.
/// `ignr`, when non-empty, is an explicitly requested result index string.
/// `onlyUniq` then returns it as the free list, and `onlyDummy` treats every
/// other letter as a dummy.
/// Nodes that are not Tree/Leaf/Func/FnTensor (e.g. plain scalars) yield ("", "").
template getIndex(Node,string ignr = "")
{
    // Free-index list of a child node (analysed with the same ignr).
    template getunq(Leaf)
    {
        alias getunq = Alias!(getIndex!(Leaf,ignr)[0]);
    }
    // Dummy-index list of a child node (analysed with the same ignr).
    template getdmy(Leaf)
    {
        alias getdmy = Alias!(getIndex!(Leaf,ignr)[1]);
    }
    static if(is(Node == Tree!(Ridx,L,R,OP,Leafs),string Ridx,L,R,string OP,Leafs...))
    {
        alias LHS = getIndex!(L,ignr);
        alias RHS = getIndex!(R,ignr);
        static if(OP == "*" || OP == "/")
        {
            // Product/quotient: getExp inlines both operand expressions into
            // this node's summand, so the operands' dummies are summed here too.
            //   free  = onlyUniq over both operands' free and dummy lists,
            //           with the operands' dummies removed;
            //   dummy = letters of the operands' free lists picked by onlyDummy
            //           (those shared by both when ignr is empty), plus the
            //           operands' dummies, sorted and de-duplicated.
            // Earlier variant, kept for reference:
            // alias getIndex = AliasSeq!((LHS[0]~RHS[0]~LHS[1]~RHS[1]).onlyUniq(ignr),((LHS[0]~RHS[0]).onlyDummy(ignr)~LHS[1]~RHS[1]).array.sort.uniq.array.to!string);
            alias getIndex = AliasSeq!((LHS[0]~RHS[0]~LHS[1]~RHS[1]).onlyUniq(ignr).removeCharacters(LHS[1]~RHS[1]),((LHS[0]~RHS[0]).onlyDummy(ignr)~LHS[1]~RHS[1]).array.sort.uniq.array.to!string);
            // Earlier variant, kept for reference:
            // alias getIndex = AliasSeq!((LHS[0]~RHS[0]~LHS[1]~RHS[1]).onlyUniq(ignr),(LHS[0]~RHS[0]~LHS[1]~RHS[1]).onlyDummy(ignr));
        }
        else 
        {
            // Any other operator: no dummies at this level; the free indices
            // are the sorted union of the operands' free indices. For "+" and
            // "-", getExp evaluates each operand through its own calc(), which
            // already sums that operand's dummies.
            // NOTE(review): getExp uses that per-operand calc() form only for
            // "+" and "-"; any other operator reaching this branch is inlined
            // there, so its operands' dummies would not be looped over —
            // confirm only "+"/"-" are expected here.
            alias getIndex = AliasSeq!((LHS[0]~RHS[0]).array.sort.uniq.array.to!string ,"");
            // Earlier variant, kept for reference:
            // alias getIndex = AliasSeq!((LHS[0]~RHS[0]).onlyUniq(ignr),((LHS[0]~RHS[0]).onlyDummy(ignr)~LHS[1]~RHS[1]).array.sort.uniq.array.to!string);
        }
    }
    else static if(is(Node == Leaf!(Ridx,idx,Ts),string Ridx,string idx,Ts))
    {
        // A leaf's own letters: once-occurring letters are free, repeated ones
        // (e.g. the i of A.ii) are dummies.
        alias getIndex = AliasSeq!(idx.onlyUniq(ignr),idx.onlyDummy(ignr));
    }
    else static if(is(Node == Func!(Ridx,fun,Leafs),string Ridx,alias fun,Leafs...))
    {
        // Function call: all arguments' free and dummy lists are concatenated
        // and re-classified (getExp inlines the call via Func.asExp).
        // NOTE(review): a letter that is a dummy of exactly one argument
        // (e.g. br!f(A.ii)) occurs once in the concatenation and so becomes
        // free here — confirm this is intended.
        alias UNQs = staticMap!(getunq,Leafs);
        alias DMYs = staticMap!(getdmy,Leafs);
        alias getIndex = AliasSeq!(join([UNQs,DMYs]).onlyUniq(ignr),join([UNQs,DMYs]).onlyDummy(ignr));
    }
    else static if(is(Node == FnTensor!(idx,F),string idx,F))
    {
        // Same classification as a Leaf.
        alias getIndex = AliasSeq!(idx.onlyUniq(ignr),idx.onlyDummy(ignr));
    }
    else
    {
        alias getIndex = AliasSeq!("","");
    }
}
198 
/// Comma-separated list of shape expressions, one per index letter in `ijk`,
/// each produced by `Node.getShape!(This, letter)`. Evaluates to "" when
/// `ijk` is empty.
template getShapes(Node,string This,string ijk)
{
    static if(ijk.length == 0)
    {
        alias getShapes = Alias!"";
    }
    else
    {
        // Shape expression of the leading index letter.
        alias first = Alias!(Node.getShape!(This,ijk[0..1]));
        static if(ijk.length == 1)
        {
            alias getShapes = first;
        }
        else
        {
            alias getShapes = Alias!(first ~ "," ~ getShapes!(Node,This,ijk[1..$]));
        }
    }
}
/// Builds, at compile time, the D source of the summand that a node's
/// generated `calc` accumulates, with index letters used as variable names.
///
/// `This` is the source-level access path of `Node` (e.g. "this._lhs").
/// `N` is a counter incremented once per leaf and threaded from the left
/// operand to the right one. Evaluates to `AliasSeq!(expression, counter)`.
/// Within this template, the counter does not affect the generated text.
template getExp(string This,Node,ulong N)
{
    static if(is(Node == Tree!(Ridx,L,R,OP,Leafs),string Ridx,L,R,string OP,Leafs...))
    {
            alias LHS = getExp!(This~"._lhs",L,N);
            alias RHS = getExp!(This~"._rhs",R,LHS[1]);
        static if(OP == "+" || OP =="-")
        {
            // Each operand is evaluated through its own calc(), which sums
            // that operand's dummy indices; only free indices are passed in.
            // Operands without a member calc (e.g. scalars) resolve via UFCS to
            // the module-level identity `calc`. LHS/RHS are used here only to
            // thread the counter.
            // NOTE(review): argsL/argsR are derived with this node's Ridx, while
            // the operands' calc() signatures come from their own ResultIdx
            // (presumably "" for nodes built by opBinary). With a non-empty Ridx
            // a Leaf operand reports Ridx itself as its free list, so e.g.
            // `Einsum.ij | A.ij + x.i` would call x's calc with two arguments —
            // confirm.
            alias argsL  = Alias!(getIndex!(L,Ridx)[0].map!"[a]".join(",").to!string);
            alias argsR  = Alias!(getIndex!(R,Ridx)[0].map!"[a]".join(",").to!string);
            alias getExp = AliasSeq!(This~"._lhs.calc("~argsL~")"  ~ OP ~ This~"._rhs.calc("~argsR~")" ,RHS[1]);
        }
        else
        {
            // Other operators: both operand expressions are inlined.
            alias getExp = AliasSeq!("("~LHS[0]~")" ~ OP ~ "("~RHS[0]~")",RHS[1]);
        }
    }
    else static if(is(Node == Leaf!(Ridx,idx,Ts),string Ridx,string idx,Ts))
    {
        // Element access with the leaf's own letters, e.g. "<This>.tensor[i,j]".
        alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
        alias getExp = AliasSeq!(This~".tensor["~ijk~"]",N+1);
    }
    else static if(is(Node == Func!(Ridx,fun,Leafs),string Ridx,alias fun,Leafs...))
    {
        // Delegates to the node's own expression builder.
        alias getExp = AliasSeq!(Node.asExp!(This),N+1);
    }
    else static if(is(Node == FnTensor!(idx,F),string idx,F))
    {
        // Calls the wrapped callable with the letters, e.g. "<This>.FUN(i,j)".
        alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
        alias getExp = AliasSeq!(This~".FUN("~ijk~")",N+1);
    }
    else
    {
        // Anything else (e.g. a scalar operand) is used as-is.
        alias getExp = AliasSeq!(This,N+1);
    }
}
/// Common element type (via `std.traits.CommonType`) of the operands in
/// `Leafs`. Scalars contribute their own type and `Leaf`s their `Type`.
/// Nested `Func`/`Tree` nodes are flattened by `filterTensors` first.
template CommonTypeOfTensors(Leafs...)
{
    import std.traits;

    // Element type contributed by one flattened operand.
    template elementOf(T)
    {
        static if(!isBasicType!T)
            alias elementOf = T.Type;
        else
            alias elementOf = T;
    }

    alias CommonTypeOfTensors = CommonType!(staticMap!(elementOf,filterTensors!Leafs));
}
263 
/// Expression node for the operator `op` applied to `_lhs` and `_rhs`.
/// Unary operators are stored as `0 op operand`.
///
/// `leafs` holds a flattened copy of the operands reachable through the tree.
/// It is used to look up index extents (`getShape`) and the common element
/// type (`Type`). `ResultIdx`, when non-empty, is the explicitly requested
/// result index string forwarded to `getIndex`.
struct Tree(string ResultIdx ="",LHS,RHS,string op,Leafs...) 
{
    LHS _lhs;
    RHS _rhs;
    Leafs leafs;
    alias LeafTypes = Leafs;
    this(LHS lhs_,RHS rhs_,Leafs leafs_)
    {
        _lhs = lhs_;
        _rhs = rhs_;
        leafs = leafs_;
    }
    
    /// Builds `this op ohs`. When `ohs` is itself a Tree its flattened leaves
    /// are appended; otherwise `ohs` becomes a leaf.
    auto opBinary(string op,R)(R ohs) @nogc nothrow 
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,aLeafs),string Ridx,Lhs,Rhs,string OP,aLeafs...))
            return Tree!("",typeof(this),R,op,Leafs,aLeafs)(this,ohs,leafs,ohs.leafs);
        else
            return Tree!("",typeof(this),R,op,Leafs,R)(this,ohs,leafs,ohs);
    }
    
    /// Builds `ohs op this`, e.g. for a scalar on the left-hand side.
    auto opBinaryRight(string op,R)(R ohs) @nogc nothrow 
    {
        return Tree!("",R,typeof(this),op,R,Leafs)(ohs,this,ohs,leafs);
    }
    
    /// Builds `op this` as the binary expression `0 op this`.
    auto opUnary(string op)() @nogc nothrow 
    {
        // An explicitly typed zero (as Leaf.opUnary uses) instead of the int
        // literal 0, so element types without an implicit int conversion work.
        return Tree!("",Type,typeof(this),op,Type,Leafs)(cast(Type)0,this,cast(Type)0,leafs);
    }

    /// Common element type of the flattened operands.
    alias Type = CommonTypeOfTensors!(Leafs);

    mixin(genCalc);

    /// Generates the source of `calc`. Its parameters are the free indices
    /// from `getIndex`. One nested `foreach` per dummy index runs over the
    /// extents given by `getShapes`. The accumulated summand is the expression
    /// built by `getExp`.
    static auto genCalc()
    {
        alias indexes  = getIndex!(typeof(this),ResultIdx);
        string UNIQ = (indexes[0].map!(c=>"size_t "~[c]).join(",").to!string);
        string result= "auto calc("~ UNIQ ~") @nogc nothrow{\n";
        result ~= "\t size_t["~indexes[1].length.to!string~"] shape = [" ~ getShapes!(typeof(this),"this",indexes[1]) ~"];\n";
        result ~= "\t auto result = cast(Type)0;\n";
        // One loop per dummy index; expands to nothing when there are none.
        static foreach(i,ijk;indexes[1])
        {
            result ~= "\tforeach("~ijk~";0..shape["~i.to!string~"])\n";
        }
        result~= "\t\tresult += "~getExp!("this",typeof(this),0)[0]~";\n";
        result~= "\treturn result;\n";
        result~= "}\n";
        return result;
    }

    /// Source expression for the extent of index `ijk`, taken from the first
    /// leaf that reports one, or "" when no leaf does.
    static auto getShape(string This = "this",string ijk)()
    {
        static foreach(i,Node;Leafs)
        {{ 
            static if(__traits(hasMember,Node,"getShape"))
            {
                string result = Node.getShape!(This~".leafs["~i.to!string~"]",ijk);
            }
            else
            {
                string result = "";
            }
            if(result != "")
            {
                return result;
            }
        }}
        return "";
    }

}
350 
351 
352 
/// Operand leaf: a tensor annotated with the index string `idx`, one letter
/// per dimension (e.g. "ij").
struct Leaf(string ResultIdx ="",string idx,aTensor)
{
    this(aTensor t)
    {
        tensor = t;
    }

    /// Source expression `This.shape[k]`, where `k` is the first position of
    /// `ijk` in `idx`, or "" when this leaf does not use `ijk`.
    static auto getShape(string This = "this",string ijk)()
    {
        static if(idx.countUntil(ijk) >= 0)
        {
            return This ~ ".shape["~idx.countUntil(ijk).to!string ~"]";
        }
        else
        {
            return "";
        }
    }

    aTensor tensor;
    // NOTE(review): presumably the tensor's element type — confirm against aTensor.
    alias Type = aTensor.Type;
    @nogc nothrow typeof(this)[1] leafs() => [this];
    alias LeafTypes = AliasSeq!(typeof(this));

    @nogc nothrow auto shape() => tensor.shape;

    /// Builds `this op ohs`, flattening `ohs`'s leaves when it is a Tree.
    auto opBinary(string op,R)(R ohs)  @nogc nothrow 
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs),string Ridx,Lhs,Rhs,string OP,Leafs...))
            return Tree!("",typeof(this),R,op,typeof(this),Leafs)(this,ohs,this,ohs.leafs);
        else
            return Tree!("",typeof(this),R,op,typeof(this),R)(this,ohs,this,ohs);
    }

    /// Builds `ohs op this` (this leaf is the right-hand operand).
    auto opBinaryRight(string op,R)(R ohs)  @nogc nothrow 
    {
        // `ohs` is the left operand, so it is passed as _lhs: the Tree
        // constructor takes (lhs, rhs, leafs...).
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs),string Ridx,Lhs,Rhs,string OP,Leafs...))
            return Tree!("",R,typeof(this),op,Leafs,typeof(this))(ohs,this,ohs.leafs,this);
        else
            return Tree!("",R,typeof(this),op,R,typeof(this))(ohs,this,ohs,this);
    }

    /// Builds `op this` as the binary expression `0 op this`.
    auto opUnary(string op)() @nogc nothrow 
    {
        return Tree!("",Type,typeof(this),op,typeof(this))(cast(Type)0,this,this);
    }
    mixin(genCalc);

    /// Generates the source of `calc`. It takes the free indices as
    /// parameters, loops over the repeated (dummy) indices, and accumulates
    /// `tensor[idx...]`.
    static auto genCalc()
    {
        alias indexes  = getIndex!(typeof(this),ResultIdx);
        string UNIQ = (indexes[0].map!(c=>"size_t "~[c]).join(",").to!string);
        string DMMY = (indexes[1].map!(c=>[c]).join(",").to!string);
        string result="auto calc(" ~ UNIQ ~ ") @nogc nothrow {\n";
        result ~= "\t size_t["~indexes[1].length.to!string~"] shape = [" ~ getShapes!(typeof(this),"this",indexes[1]) ~"];\n";
        result ~= "\t auto result = cast(Type)0;\n";

        if(DMMY.length > 0)
        {
            static foreach(i,ijk;indexes[1])
            {
                result ~= "\tforeach("~ijk~";0..shape["~i.to!string~"])\n";
            }
        }
        alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
        result~= "\t\tresult += tensor["~ijk~"];\n";
        
        result~= "\treturn result;\n";
        result~= "}\n";
        return result;
    }

}
425 
426 
/// Deferred application of `fun` to einsum operands, built by `br!fun(args)`.
///
/// `leafs` holds the call arguments (index-carrying nodes or plain values).
/// The generated `calc` evaluates `fun` at the current index values and
/// accumulates it over the dummy indices reported by `getIndex`.
struct Func(string ResultIdx ="",alias fun,Leafs...)
{
    alias FUN = fun;
    // Use fun's return type when ReturnType!fun is valid (it is not, e.g., for
    // a templated fun); otherwise use the common element type of the operands.
    static if(is(ReturnType!fun))
        alias Type = ReturnType!fun;
    else 
        alias Type = CommonTypeOfTensors!Leafs;
    alias ArgLength = Alias!(Leafs.length); 
    Leafs leafs;
    this(Leafs args_)
    {
        leafs = args_;
    }
    
    mixin(genCalc);

    /// Generates the source of `calc`: free indices as parameters, one loop
    /// per dummy index, and `asExp` as the accumulated summand.
    static auto genCalc()
    {
        alias indexes  = getIndex!(typeof(this),ResultIdx);
        string UNIQ = (indexes[0].map!(c=>"size_t "~[c]).join(",").to!string);
        string DMMY = (indexes[1].map!(c=>[c]).join(",").to!string);
        string result="auto calc(" ~ UNIQ ~ ") @nogc nothrow {\n";
        result ~= "\t size_t["~indexes[1].length.to!string~"] shape = [" ~ getShapes!(typeof(this),"this",indexes[1]) ~"];\n";
        result ~= "\t auto result = cast(Type)0;\n";

        if(DMMY.length > 0)
        {
            static foreach(i,ijk;indexes[1])
            {
                result ~= "\tforeach("~ijk~";0..shape["~i.to!string~"])\n";
            }
        }
        result~= "\t\tresult += "~asExp~";\n";
        result~= "\treturn result;\n";
        result~= "}\n";
        return result;
    }

    /// Source expression for the extent of index `ijk`, taken from the first
    /// argument that reports one, or "" when none does.
    static auto getShape(string This = "this",string ijk)()
    {
        static foreach(i,Node;Leafs)
        {{ 
            static if(__traits(hasMember,Node,"getShape"))
            {
                auto result = Node.getShape!(This~".leafs["~i.to!string~"]",ijk);
            }
            else
            {
                auto result = "";
            }
            if(result != "")
            {
                return result;
            }
        }}
        return "";
    }
    
    /// Source of the call `This.FUN(arg0, arg1, ...)`. Every argument is
    /// addressed through `This`, so the result stays valid when this node is
    /// nested inside another node.
    static auto asExp(string This = "this")()
    {
        string result = This~".FUN(";
        static foreach(i,Node;Leafs)
        {{ 
            // Access path of the i-th argument.
            enum path = This~".leafs["~i.to!string~"]";
            static if(is(Node == Tree!(Ridx,LHS,RHS,op,Leafs_),string Ridx,LHS,RHS,string op,Leafs_...))
            {
                result ~= getExp!(path,Node,0)[0] ~ ",";
            }
            else static if(is(Node == Leaf!(Ridx,idx,Tns),string Ridx,string idx,Tns))
            {
                alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
                result ~= path~".tensor[" ~ijk~"],";
            }
            else static if(is(Node == Func!(Ridx,fun_,Leafs_),string Ridx,alias fun_,Leafs_...))
            {
                // Nested call: prefix with This (was a bare "leafs[i]").
                result ~= Node.asExp!(path)~",";
            }
            else static if(is(Node == FnTensor!(idx,F),string idx,F))
            {
                // Path string, not the type (concatenating a type was a compile error).
                alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
                result ~= path~".FUN("~ijk~"),";
            }
            else
            {
                result ~= path~",";
            }
        }}
        result ~= ")";
        return result;
    }

    /// Builds `this op ohs`, flattening `ohs`'s leaves when it is a Tree.
    auto opBinary(string op,R)(R ohs)  @nogc nothrow 
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs_),string Ridx,Lhs,Rhs,string OP,Leafs_...))
            return Tree!("",typeof(this),R,op,typeof(this),Leafs_)(this,ohs,this,ohs.leafs);
        else
            return Tree!("",typeof(this),R,op,typeof(this),R)(this,ohs,this,ohs);
    }

    /// Builds `ohs op this`. `ohs` is the left operand and becomes _lhs;
    /// the previous version built `this op ohs`, which is wrong for "-" and "/".
    auto opBinaryRight(string op,R)(R ohs)  @nogc nothrow 
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs_),string Ridx,Lhs,Rhs,string OP,Leafs_...))
            return Tree!("",R,typeof(this),op,Leafs_,typeof(this))(ohs,this,ohs.leafs,this);
        else
            return Tree!("",R,typeof(this),op,R,typeof(this))(ohs,this,ohs,this);
    }
}
529 
/// Tensor whose elements are computed on demand by the callable `FUN`,
/// indexed by the letters in `idx` (created via `fnTensor(f).ij`).
///
/// `getShape` always returns "", so this operand never supplies an index
/// extent. Tree/Func.getShape skip empty results and take the extent from
/// another operand.
struct FnTensor(string idx,F)
{
    F FUN;
    this(F)(F fun) if (isCallable!F)
    {
        FUN = fun;
    }

    static auto getShape(string This = "this",string ijk)()
    {
        return "";
    }

    /// Builds `this op ohs`, flattening `ohs`'s leaves when it is a Tree.
    auto opBinary(string op,R)(R ohs) @nogc nothrow 
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs_),string Ridx,Lhs,Rhs,string OP,Leafs_...))
            return Tree!("",typeof(this),R,op,typeof(this),Leafs_)(this,ohs,this,ohs.leafs);
        else
            return Tree!("",typeof(this),R,op,typeof(this),R)(this,ohs,this,ohs);
    }

    /// Builds `ohs op this`. `ohs` is the left operand and becomes _lhs;
    /// the previous version built `this op ohs`, which is wrong for "-" and "/".
    auto opBinaryRight(string op,R)(R ohs) @nogc nothrow 
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs_),string Ridx,Lhs,Rhs,string OP,Leafs_...))
            return Tree!("",R,typeof(this),op,Leafs_,typeof(this))(ohs,this,ohs.leafs,this);
        else
            return Tree!("",R,typeof(this),op,R,typeof(this))(ohs,this,ohs,this);
    }
}
557 
/// Identity fallback reached through UFCS. Generated code can write
/// `x.calc()` uniformly, even when `x` is a plain scalar operand without a
/// `calc` member (see the "+"/"-" case of `getExp`).
private auto calc(R)(R value) => value;