jupyterjazz committed
Commit 8b2ad1e
1 Parent(s): a6bb16f

fix: update frequencies when updating the rope base value (#40)


- fix: update frequencies when updating the rope base value (d8cbc92c8650d6bdc8e5afb28785625a98ccfab1)
- Update rotary.py (90873c4a21ac932b2df31d0e35e56b9c55460470)
- Update rotary.py (071760a5bbecc7b738c64583a3b5b337cd6d0667)
- Update rotary.py (1eb2361d4e9bdeedc1516196f02f199515916d30)
- Update rotary.py (066b97bdf39f4031bf1ddee4c706d5c842fb8748)
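In practice this matters for NTK-style rope-scaling workflows that retune the base on an already-constructed module. A minimal, hypothetical usage sketch, not part of the commit: the constructor arguments follow the flash-attn-style RotaryEmbedding(dim, base=..., device=...) signature, which this rotary.py is assumed to share; only the behavior of the base setter comes from this change.

    import torch
    from rotary import RotaryEmbedding  # assumed import path for the patched file

    rope = RotaryEmbedding(dim=64, base=10_000.0, device=torch.device("cpu"))
    old_inv_freq = rope.inv_freq.clone()

    # NTK-style context extension: raise the rotary base through the setter.
    rope.base = 1_000_000.0

    # Before this fix the setter only stored the value, so inv_freq (and any
    # cached cos/sin tables) still encoded base=10_000. Now both are refreshed.
    assert not torch.equal(rope.inv_freq, old_inv_freq)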

Files changed (1)
  1. rotary.py +17 -2
rotary.py CHANGED
@@ -493,8 +493,16 @@ class RotaryEmbedding(torch.nn.Module):
 
     @base.setter
     def base(self, new_base):
+        new_base = float(new_base)
         if new_base > 0:
-            self._base = float(new_base)
+            if self._base != new_base:  # only update if the base value has changed
+                self._base = new_base
+                self._update_cos_sin_cache(
+                    self._seq_len_cached,
+                    device=self.inv_freq.device,
+                    dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
+                    rotary_base_changed=True,
+                )
         else:
             raise ValueError("Rotary base value must be positive")
 
@@ -507,21 +515,27 @@ class RotaryEmbedding(torch.nn.Module):
             )
         )
 
-    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+    def _update_cos_sin_cache(
+        self, seqlen, device=None, dtype=None, rotary_base_changed=False
+    ):
         # Reset the tables if the sequence length has changed,
         # if we're on a new device (possibly due to tracing for instance),
         # or if we're switching from inference mode to training
+        # or if the rotary base value was changed
         if (
             seqlen > self._seq_len_cached
             or self._cos_cached is None
             or self._cos_cached.device != device
             or self._cos_cached.dtype != dtype
             or (self.training and self._cos_cached.is_inference())
+            or rotary_base_changed
         ):
             self._seq_len_cached = seqlen
             # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
             # And the output of arange can be quite large, so bf16 would lose a lot of precision.
             # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+            if rotary_base_changed:
+                self.inv_freq = self._compute_inv_freq(device=device)
             if self.pos_idx_in_fp32:
                 t = torch.arange(seqlen, device=device, dtype=torch.float32)
                 # We want fp32 here as well since inv_freq will be multiplied with t, and the output
@@ -535,6 +549,7 @@ class RotaryEmbedding(torch.nn.Module):
             else:
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                 inv_freq = self.inv_freq
+
             # Don't do einsum, it converts fp32 to fp16 under AMP
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, inv_freq)
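For context on why the setter has to invalidate the cache: the inverse frequencies are a direct function of the base. A small sketch of the standard RoPE definition that _compute_inv_freq is assumed to implement (the helper's name appears in the diff; its body does not):

    import torch

    def compute_inv_freq(dim: int, base: float, device=None) -> torch.Tensor:
        # Standard RoPE inverse frequencies: base^(-2i/dim) for i = 0..dim/2-1.
        # This is the common definition, assumed here; the actual
        # _compute_inv_freq in rotary.py is not shown in this diff.
        return 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )

    # Changing the base changes every frequency except the first, so cos/sin
    # tables built from the old inv_freq would silently keep the old base.
    f_10k = compute_inv_freq(64, 10_000.0)
    f_1m = compute_inv_freq(64, 1_000_000.0)
    assert f_10k[0] == 1.0 and f_1m[0] == 1.0  # the i=0 frequency ignores base
    assert bool((f_1m[1:] < f_10k[1:]).all())  # larger base -> slower rotations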