/*
Copyright (c) 2022 Zenw
Released under the MIT license
https://opensource.org/licenses/mit-license.php
*/
module ggeD.einsum;

import std;
import ggeD.ggeD;


unittest
{
    auto A = iota(9.).gged!double(3,3);
    auto x = A[0,0..$];
    assert(A == [[0, 1, 2],[3, 4, 5],[6, 7, 8]]);
    assert(x == [0, 1, 2]);

    auto Ax = Einsum | A.ij * x.i;
    assert(Ax == [15,18,21]);

    auto tr = Einsum | A.ii;
    assert(tr == 12);

    auto transposed = Einsum.ji | A.ij;
    assert(transposed == [[0, 3, 6], [1, 4, 7], [2, 5, 8]]);

    auto delta = fnTensor((ulong i,ulong j)=>(i==j?1.:0.));
    auto tr2 = Einsum | A.ij*delta.ij;
    assert(tr2 == 12);

    auto applyFunction = Einsum | br!tan(br!atan(A.ij)*1.);
    assert(A == applyFunction);

    auto applyFunction2 = Einsum | br!atan2(A[0,0..3].i,1.+A[0..3,0].i);
    assert(applyFunction2 == atan2(A[0,0],1+A[0,0]) + atan2(A[0,1],1+A[1,0]) + atan2(A[0,2],1+A[2,0]) );

    // A scalar on the left of a non-commutative operator must stay the left operand.
    auto rsub = Einsum | 10. - br!abs(x.i);
    assert(rsub == [10, 9, 8]);
}


/**
 * Wraps a call of `fun` on `args` as an expression node (`Func`).
 * The call is not made here: it is emitted into the generated `calc`
 * and evaluated per index tuple when the expression is passed to `Einsum`.
 */
@nogc auto br(alias fun,Args...)(Args args)
{
    return Func!("",fun,Args)(args);
}



/**
 * Turns a callable into an index-addressable tensor: `fnTensor(f).ij`
 * yields an `FnTensor!("ij", F)` whose element at (i, j) is `f(i, j)`.
 *
 * `Sub` is a `static` struct that stores its own copy of `fun`, so the
 * returned value does not depend on a context pointer into this
 * function's (already finished) stack frame.
 */
@nogc auto fnTensor(F)(F fun) if(isCallable!F)
{
    static struct Sub
    {
        F _fun;
        auto opDispatch(string idx)()
        {
            return FnTensor!(idx,F)(_fun);
        }
    }
    return Sub(fun);
}


/**
 * Entry point for Einstein-summation expressions.
 *
 * `Einsum | expr` evaluates `expr`, inferring the result indices;
 * `Einsum.ji | expr` evaluates it with the result indices fixed to "ji"
 * (in that order). Indices that occur once are free (they index the
 * result); indices that occur more than once are summed over.
 */
struct Einsum
{
static:
    /// Evaluates expression node `leaf`. Returns a scalar when no free index
    /// remains, otherwise a `gged` slice over the free indices.
    auto evalnogc(string ResultIdx="",Node)(Node leaf) @nogc nothrow
    {
        import mir.ndslice;
        import ggeD.iterator;
        alias indexes = getIndex!(Node,ResultIdx);
        // Re-instantiate the node type with ResultIdx as template argument 0
        // and copy the node's state into it.
        // NOTE(review): assumes argument 0 of Node is its ResultIdx; this holds for
        // Tree/Leaf/Func, but FnTensor's first argument is its index string.
        alias Origin = TemplateOf!Node;
        alias Args = TemplateArgsOf!Node;
        alias NewOne = Origin!(ResultIdx,Args[1..$]);
        auto newone = NewOne(leaf.tupleof);
        static if(indexes[0].length > 0)
        {
            // Free indices remain: build an EinsumIterator-backed slice over them.
            alias shapes= Alias!(getShapes!(NewOne,"newone",indexes[0]));
            auto arr = mixin("iota("~shapes~")");
            size_t[indexes[0].length] shape = mixin("["~shapes~"]");
            auto itr = EinsumIterator!(typeof(arr._iterator),typeof(arr),NewOne)(arr._iterator,arr,newone);
            return Slice!(typeof(itr),indexes[0].length)(shape,itr).gged;
        }
        else
        {
            // Every index is summed: the expression reduces to a scalar.
            return newone.calc();
        }
    }

    /// `Einsum | expr`: evaluate with inferred result indices.
    auto opBinary(string op,R)(R rhs) @nogc nothrow if(op == "|")
    {
        return evalnogc(rhs);
    }

    /// `Einsum.<indices> | expr`: evaluate with the result indices given by `name`.
    auto opDispatch(string name)() @nogc nothrow
    {
        struct Sub
        {
            auto opBinary(string op,R)(R rhs) if(op == "|")
            {
                return evalnogc!name(rhs);
            }
        }

        return Sub();
    }

}

package(ggeD):

/// Indices occurring exactly once in `input` (the free indices), in order of
/// appearance. If `ignr` (an explicitly requested result-index list) is
/// non-empty, it is returned unchanged instead.
string onlyUniq(string input,string ignr="")
{
    return ignr != "" ? ignr : input.to!(dchar[]).filter!(a=>input.count(a) == 1).array.to!string;
}

/// Summed indices, sorted and de-duplicated: those occurring more than once in
/// `input`, or, when `ignr` is non-empty, every index of `input` not in `ignr`.
string onlyDummy(string input,string ignr="")
{
    return ignr != ""
        ? input.to!(dchar[]).filter!(a=>ignr.count(a)==0).array.sort.uniq.array.to!string
        : input.to!(dchar[]).filter!(a=>input.count(a) > 1 ).array.sort.uniq.array.to!string;
}


/// Flattens expression nodes into the `Leaf` and basic-type operands they
/// contain (recursing through `Func` and `Tree`); any other node, including
/// `FnTensor`, contributes nothing.
template filterTensors(Leafs...)
{
    static if(Leafs.length == 1)
    {
        static if(is(Leafs[0] == Leaf!(Ridx,idx,T),string Ridx,string idx,T) || isBasicType!(Leafs[0]))
        {
            alias filterTensors = Leafs[0];
        }
        else static if(is(Leafs[0] == Func!(Ridx,func,Leafs_),string Ridx,alias func,Leafs_...))
        {
            alias filterTensors = filterTensors!(Leafs_);
        }
        else static if(is(Leafs[0] == Tree!(Ridx,L,R,OP,Leafs_),string Ridx,L,R,string OP,Leafs_...))
        {
            alias filterTensors = filterTensors!(Leafs_);
        }
        else
        {
            alias filterTensors = AliasSeq!();
        }
    }
    else static if(Leafs.length == 0)
    {
        alias filterTensors = AliasSeq!();
    }
    else
    {
        alias filterTensors = AliasSeq!(filterTensors!(Leafs[0]),filterTensors!(Leafs[1..$]));
    }
}

/// Returns `A` with every character that appears in `B` removed.
string removeCharacters(string A, string B)
{
    // Keep only the characters of A that do not occur in B.
    return A.filter!(c => !B.canFind(c)).array.to!string;
}

/**
 * Computes `AliasSeq!(free, summed)` index strings for expression node `Node`.
 * `ignr` is the explicitly requested result-index string ("" to infer).
 *
 * For `*` and `/`, an index is free if it occurs once across both operands and
 * is not already summed inside either operand. It is summed if it repeats
 * across the operands or was summed inside one of them. For `+` and `-`, the
 * free indices are the sorted union of the operands' free indices, and nothing
 * is summed at this level.
 */
template getIndex(Node,string ignr = "")
{
    template getunq(Leaf)
    {
        alias getunq = Alias!(getIndex!(Leaf,ignr)[0]);
    }
    template getdmy(Leaf)
    {
        alias getdmy = Alias!(getIndex!(Leaf,ignr)[1]);
    }
    static if(is(Node == Tree!(Ridx,L,R,OP,Leafs),string Ridx,L,R,string OP,Leafs...))
    {
        alias LHS = getIndex!(L,ignr);
        alias RHS = getIndex!(R,ignr);
        static if(OP == "*" || OP == "/")
        {
            alias getIndex = AliasSeq!(
                (LHS[0]~RHS[0]~LHS[1]~RHS[1]).onlyUniq(ignr).removeCharacters(LHS[1]~RHS[1]),
                ((LHS[0]~RHS[0]).onlyDummy(ignr)~LHS[1]~RHS[1]).array.sort.uniq.array.to!string);
        }
        else
        {
            alias getIndex = AliasSeq!((LHS[0]~RHS[0]).array.sort.uniq.array.to!string ,"");
        }
    }
    else static if(is(Node == Leaf!(Ridx,idx,Ts),string Ridx,string idx,Ts))
    {
        alias getIndex = AliasSeq!(idx.onlyUniq(ignr),idx.onlyDummy(ignr));
    }
    else static if(is(Node == Func!(Ridx,fun,Leafs),string Ridx,alias fun,Leafs...))
    {
        alias UNQs = staticMap!(getunq,Leafs);
        alias DMYs = staticMap!(getdmy,Leafs);
        alias getIndex = AliasSeq!(join([UNQs,DMYs]).onlyUniq(ignr),join([UNQs,DMYs]).onlyDummy(ignr));
    }
    else static if(is(Node == FnTensor!(idx,F),string idx,F))
    {
        alias getIndex = AliasSeq!(idx.onlyUniq(ignr),idx.onlyDummy(ignr));
    }
    else
    {
        alias getIndex = AliasSeq!("","");
    }
}

/// Comma-separated code strings giving the extent of each index in `ijk`,
/// as reported by `Node.getShape` with `This` as the code path to the node.
template getShapes(Node,string This,string ijk)
{
    static if(ijk.length == 1)
    {
        alias getShapes = Alias!(Node.getShape!(This,ijk));
    }
    else static if(ijk.length == 0)
    {
        alias getShapes = Alias!"";
    }
    else
    {
        alias getShapes = Alias!(getShapes!(Node,This,ijk[0..1]) ~ "," ~ getShapes!(Node,This,ijk[1..$]));
    }
}

/**
 * Builds the element-wise code string for `Node`, where `This` is the code
 * path to the node. Yields `AliasSeq!(code, counter)`; the counter is
 * incremented once per terminal node visited, starting from `N`.
 */
template getExp(string This,Node,ulong N)
{
    static if(is(Node == Tree!(Ridx,L,R,OP,Leafs),string Ridx,L,R,string OP,Leafs...))
    {
        alias LHS = getExp!(This~"._lhs",L,N);
        alias RHS = getExp!(This~"._rhs",R,LHS[1]);
        static if(OP == "+" || OP =="-")
        {
            // Each side of a sum/difference is reduced by its own calc(), called
            // with only that side's free indices.
            alias argsL = Alias!(getIndex!(L,Ridx)[0].map!"[a]".join(",").to!string);
            alias argsR = Alias!(getIndex!(R,Ridx)[0].map!"[a]".join(",").to!string);
            alias getExp = AliasSeq!(This~"._lhs.calc("~argsL~")" ~ OP ~ This~"._rhs.calc("~argsR~")" ,RHS[1]);
        }
        else
        {
            alias getExp = AliasSeq!("("~LHS[0]~")" ~ OP ~ "("~RHS[0]~")",RHS[1]);
        }
    }
    else static if(is(Node == Leaf!(Ridx,idx,Ts),string Ridx,string idx,Ts))
    {
        alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
        alias getExp = AliasSeq!(This~".tensor["~ijk~"]",N+1);
    }
    else static if(is(Node == Func!(Ridx,fun,Leafs),string Ridx,alias fun,Leafs...))
    {
        alias getExp = AliasSeq!(Node.asExp!(This),N+1);
    }
    else static if(is(Node == FnTensor!(idx,F),string idx,F))
    {
        alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
        alias getExp = AliasSeq!(This~".FUN("~ijk~")",N+1);
    }
    else
    {
        alias getExp = AliasSeq!(This,N+1);
    }
}

/// Common element type of the tensors and scalars found by `filterTensors`.
template CommonTypeOfTensors(Leafs...)
{
    import std.traits;
    template getType(T)
    {
        static if(isBasicType!T)
            alias getType = T;
        else
            alias getType = T.Type;
    }
    alias Types = staticMap!(getType,filterTensors!Leafs);
    alias CommonTypeOfTensors = CommonType!(Types);
}

/**
 * Binary expression node `_lhs op _rhs`. A unary `op x` is represented as
 * `0 op x`.
 *
 * `Leafs` is the flattened operand list, used for type deduction and shape
 * lookup. `ResultIdx` is the explicit result-index string ("" to infer).
 */
struct Tree(string ResultIdx ="",LHS,RHS,string op,Leafs...)
{
    LHS _lhs;
    RHS _rhs;
    Leafs leafs;
    alias LeafTypes = Leafs;
    this(LHS lhs_,RHS rhs_,Leafs leafs_)
    {
        _lhs = lhs_;
        _rhs = rhs_;
        leafs = leafs_;
    }

    /// Builds `this op ohs`.
    auto opBinary(string op,R)(R ohs) @nogc nothrow
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,aLeafs),string Ridx,Lhs,Rhs,string OP,aLeafs...))
            return Tree!("",typeof(this),R,op,Leafs,aLeafs)(this,ohs,leafs,ohs.leafs);
        else
            return Tree!("",typeof(this),R,op,Leafs,R)(this,ohs,leafs,ohs);
    }

    /// Builds `ohs op this` (`ohs` is the left operand).
    auto opBinaryRight(string op,R)(R ohs) @nogc nothrow
    {
        return Tree!("",R,typeof(this),op,R,Leafs)(ohs,this,ohs,leafs);
    }

    /// Builds `op this` as `0 op this`; the zero is typed as `Type`, as in `Leaf.opUnary`.
    auto opUnary(string op)() @nogc nothrow
    {
        return Tree!("",Type,typeof(this),op,Type,Leafs)(cast(Type)0,this,cast(Type)0,leafs);
    }

    alias Type = CommonTypeOfTensors!(Leafs);

    mixin(genCalc);

    /// Generates `calc(<free indices>)`. It nests one foreach per summed index
    /// and accumulates the expression produced by `getExp`.
    static auto genCalc()
    {
        alias indexes = getIndex!(typeof(this),ResultIdx);
        string UNIQ = (indexes[0].map!(c=>"size_t "~[c]).join(",").to!string);
        string DMMY = (indexes[1].map!(c=>[c]).join(",").to!string);
        string result= "auto calc("~ UNIQ ~") @nogc nothrow{\n";
        result ~= "\t size_t["~indexes[1].length.to!string~"] shape = [" ~ getShapes!(typeof(this),"this",indexes[1]) ~"];\n";
        result ~= "\t auto result = cast(Type)0;\n";
        if(DMMY.length > 0)
        {
            static foreach(i,ijk;indexes[1])
            {
                result ~= "\tforeach("~ijk~";0..shape["~i.to!string~"])\n";
            }
        }
        result~= "\t\tresult += "~getExp!("this",typeof(this),0)[0]~";\n";
        result~= "\treturn result;\n";
        result~= "}\n";
        return result;
    }

    /// Code string for the extent of index `ijk`, taken from the first leaf
    /// that reports one, or "" if none does.
    static auto getShape(string This = "this",string ijk)()
    {
        static foreach(i,Node;Leafs)
        {{
            static if(__traits(hasMember,Node,"getShape"))
            {
                string result = Node.getShape!(This~".leafs["~i.to!string~"]",ijk);
            }
            else
            {
                string result = "";
            }
            if(result != "")
            {
                return result;
            }
        }}
        return "";
    }

}



/// A tensor operand indexed by the letters of `idx`, e.g. `Leaf!("", "ij", T)`.
struct Leaf(string ResultIdx ="",string idx,aTensor)
{
    this(aTensor t)
    {
        tensor = t;
    }

    /// Code string for the extent of index `ijk`, or "" if `idx` does not contain it.
    static auto getShape(string This = "this",string ijk)()
    {
        static if(idx.countUntil(ijk) >= 0)
        {
            return This ~ ".shape["~idx.countUntil(ijk).to!string ~"]";
        }
        else
        {
            return "";
        }
    }



    aTensor tensor;
    alias Type = aTensor.Type;
    @nogc nothrow typeof(this)[1] leafs() => [this];
    alias LeafTypes = AliasSeq!(typeof(this));


    @nogc nothrow auto shape() => tensor.shape;

    /// Builds `this op ohs`.
    auto opBinary(string op,R)(R ohs) @nogc nothrow
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs),string Ridx,Lhs,Rhs,string OP,Leafs...))
            return Tree!("",typeof(this),R,op,typeof(this),Leafs)(this,ohs,this,ohs.leafs);
        else
            return Tree!("",typeof(this),R,op,typeof(this),R)(this,ohs,this,ohs);
    }

    /// Builds `ohs op this` (`ohs` is the left operand).
    auto opBinaryRight(string op,R)(R ohs) @nogc nothrow
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs),string Ridx,Lhs,Rhs,string OP,Leafs...))
            return Tree!("",R,typeof(this),op,Leafs,typeof(this))(ohs,this,ohs.leafs,this);
        else
            return Tree!("",R,typeof(this),op,R,typeof(this))(ohs,this,ohs,this);
    }

    /// Builds `op this` as `0 op this`.
    auto opUnary(string op)() @nogc nothrow
    {
        return Tree!("",Type,typeof(this),op,typeof(this))(cast(Type)0,this,this);
    }

    mixin(genCalc);

    /// Generates `calc(<free indices>)`, summing `tensor[idx...]` over any
    /// repeated index (e.g. the trace for "ii").
    static auto genCalc()
    {
        alias indexes = getIndex!(typeof(this),ResultIdx);
        string UNIQ = (indexes[0].map!(c=>"size_t "~[c]).join(",").to!string);
        string DMMY = (indexes[1].map!(c=>[c]).join(",").to!string);
        string result="auto calc(" ~ UNIQ ~ ") @nogc nothrow {\n";
        result ~= "\t size_t["~indexes[1].length.to!string~"] shape = [" ~ getShapes!(typeof(this),"this",indexes[1]) ~"];\n";
        result ~= "\t auto result = cast(Type)0;\n";

        if(DMMY.length > 0)
        {
            static foreach(i,ijk;indexes[1])
            {
                result ~= "\tforeach("~ijk~";0..shape["~i.to!string~"])\n";
            }
        }
        alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
        result~= "\t\tresult += tensor["~ijk~"];\n";

        result~= "\treturn result;\n";
        result~= "}\n";
        return result;
    }

}


/// Function-application node: `fun` applied element-wise to the operands `leafs`.
struct Func(string ResultIdx ="",alias fun,Leafs...)
{
    alias FUN = fun;
    static if(is(ReturnType!fun))
        alias Type = ReturnType!fun;
    else
        alias Type = CommonTypeOfTensors!Leafs;
    alias ArgLength = Alias!(Leafs.length);
    Leafs leafs;
    this(Leafs args_)
    {
        leafs = args_;
    }

    mixin(genCalc);

    /// Generates `calc(<free indices>)`, accumulating the call built by
    /// `asExp` over every summed index.
    static auto genCalc()
    {
        alias indexes = getIndex!(typeof(this),ResultIdx);
        string UNIQ = (indexes[0].map!(c=>"size_t "~[c]).join(",").to!string);
        string DMMY = (indexes[1].map!(c=>[c]).join(",").to!string);
        string result="auto calc(" ~ UNIQ ~ ") @nogc nothrow {\n";
        result ~= "\t size_t["~indexes[1].length.to!string~"] shape = [" ~ getShapes!(typeof(this),"this",indexes[1]) ~"];\n";
        result ~= "\t auto result = cast(Type)0;\n";

        if(DMMY.length > 0)
        {
            static foreach(i,ijk;indexes[1])
            {
                result ~= "\tforeach("~ijk~";0..shape["~i.to!string~"])\n";
            }
        }
        result~= "\t\tresult += "~asExp~";\n";
        result~= "\treturn result;\n";
        result~= "}\n";
        return result;
    }

    /// Code string for the extent of index `ijk`, taken from the first
    /// argument that reports one, or "" if none does.
    static auto getShape(string This = "this",string ijk)()
    {
        static foreach(i,Node;Leafs)
        {{
            static if(__traits(hasMember,Node,"getShape"))
            {
                auto result = Node.getShape!(This~".leafs["~i.to!string~"]",ijk);
            }
            else
            {
                auto result = "";
            }
            if(result != "")
            {
                return result;
            }
        }}
        return "";
    }

    /// Code string `This.FUN(arg0, arg1, ...)`, where each argument is the
    /// element expression of the corresponding leaf. `This` is the code path
    /// to this node, and every argument path is built relative to it.
    static auto asExp(string This = "this")()
    {
        string result = This~".FUN(";
        static foreach(i,Arg;Leafs)
        {{
            enum string leafPath = This~".leafs["~i.to!string~"]";
            static if(is(Leafs[i] == Tree!(Ridx,LHS,RHS,op,Leafs_),string Ridx,LHS,RHS,string op,Leafs_...))
            {
                result ~= getExp!(leafPath,Leafs[i],0)[0] ~ ",";
            }
            else static if(is(Leafs[i] == Leaf!(Ridx,idx,Tns),string Ridx,string idx,Tns))
            {
                alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
                result ~= leafPath~".tensor[" ~ijk~"],";
            }
            else static if(is(Leafs[i] == Func!(Ridx,fun_,Leafs_),string Ridx,alias fun_,Leafs_...))
            {
                // Nested call: its path must be rooted at This, not at the outermost node.
                result ~= Leafs[i].asExp!(leafPath)~",";
            }
            else static if(is(Leafs[i] == FnTensor!(idx,F),string idx,F))
            {
                alias ijk = Alias!(idx.map!(c=>[c]).join(",").to!string);
                result ~= leafPath~".FUN("~ijk~"),";
            }
            else
            {
                result ~= leafPath~",";
            }
        }}
        result ~= ")";
        return result;
    }

    /// Builds `this op ohs`.
    auto opBinary(string op,R)(R ohs) @nogc nothrow
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs_),string Ridx,Lhs,Rhs,string OP,Leafs_...))
            return Tree!("",typeof(this),R,op,typeof(this),Leafs_)(this,ohs,this,ohs.leafs);
        else
            return Tree!("",typeof(this),R,op,typeof(this),R)(this,ohs,this,ohs);
    }

    /// Builds `ohs op this`; `ohs` stays the left operand so `-` and `/` keep their order.
    auto opBinaryRight(string op,R)(R ohs) @nogc nothrow
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs_),string Ridx,Lhs,Rhs,string OP,Leafs_...))
            return Tree!("",R,typeof(this),op,Leafs_,typeof(this))(ohs,this,ohs.leafs,this);
        else
            return Tree!("",R,typeof(this),op,R,typeof(this))(ohs,this,ohs,this);
    }
}

/// A tensor defined by a callable: the element at the indices in `idx` is
/// `FUN(indices...)`. It reports no shape of its own.
struct FnTensor(string idx,F)
{
    F FUN;
    this(G)(G fun) if (isCallable!G)
    {
        FUN = fun;
    }

    /// Always "": an FnTensor has no intrinsic extent.
    static auto getShape(string This = "this",string ijk)()
    {
        return "";
    }

    /// Builds `this op ohs`.
    auto opBinary(string op,R)(R ohs) @nogc nothrow
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs_),string Ridx,Lhs,Rhs,string OP,Leafs_...))
            return Tree!("",typeof(this),R,op,typeof(this),Leafs_)(this,ohs,this,ohs.leafs);
        else
            return Tree!("",typeof(this),R,op,typeof(this),R)(this,ohs,this,ohs);
    }

    /// Builds `ohs op this`; `ohs` stays the left operand so `-` and `/` keep their order.
    auto opBinaryRight(string op,R)(R ohs) @nogc nothrow
    {
        static if(is(R == Tree!(Ridx,Lhs,Rhs,OP,Leafs_),string Ridx,Lhs,Rhs,string OP,Leafs_...))
            return Tree!("",R,typeof(this),op,Leafs_,typeof(this))(ohs,this,ohs.leafs,this);
        else
            return Tree!("",R,typeof(this),op,R,typeof(this))(ohs,this,ohs,this);
    }
}

/// UFCS fallback so that `x.calc()` on a plain scalar operand yields the scalar
/// itself. `getExp` emits `._lhs.calc(...)` / `._rhs.calc(...)` for `+` and `-`
/// even when an operand is a basic type.
private auto calc(R)(R value)
{
    return value;
}